فهرست منبع

Merge remote-tracking branch 'origin/master' into holepunch-remote-allow-list

Wade Simmons 3 ماه پیش
والد
کامیت
0d6b19ee8f
62فایلهای تغییر یافته به همراه1100 افزوده شده و 453 حذف شده
  1. 1 1
      .github/workflows/gofmt.yml
  2. 4 4
      .github/workflows/release.yml
  3. 1 1
      .github/workflows/smoke.yml
  4. 7 7
      .github/workflows/test.yml
  5. 20 6
      .golangci.yaml
  6. 7 0
      CHANGELOG.md
  7. 12 26
      allow_list.go
  8. 12 12
      allow_list_test.go
  9. 1 1
      cert/cert_v1.go
  10. 2 2
      cert/crypto_test.go
  11. 20 20
      cmd/nebula-cert/ca_test.go
  12. 10 10
      cmd/nebula-cert/keygen_test.go
  13. 1 1
      cmd/nebula-cert/main.go
  14. 8 8
      cmd/nebula-cert/print_test.go
  15. 16 16
      cmd/nebula-cert/verify_test.go
  16. 33 17
      config/config.go
  17. 19 19
      config/config_test.go
  18. 2 2
      control_test.go
  19. 8 8
      dns_server_test.go
  20. 3 3
      e2e/handshakes_test.go
  21. 2 2
      e2e/helpers_test.go
  22. 30 11
      examples/config.yml
  23. 5 5
      firewall.go
  24. 1 1
      firewall/packet.go
  25. 27 27
      firewall_test.go
  26. 8 9
      go.mod
  27. 12 14
      go.sum
  28. 23 2
      handshake_ix.go
  29. 1 1
      header/header.go
  30. 2 2
      hostmap_test.go
  31. 83 10
      inside.go
  32. 3 3
      lighthouse.go
  33. 15 15
      lighthouse_test.go
  34. 2 2
      main.go
  35. 3 1
      overlay/device.go
  36. 69 17
      overlay/route.go
  37. 147 41
      overlay/route_test.go
  38. 3 2
      overlay/tun_android.go
  39. 6 5
      overlay/tun_darwin.go
  40. 3 2
      overlay/tun_disabled.go
  41. 5 4
      overlay/tun_freebsd.go
  42. 3 2
      overlay/tun_ios.go
  43. 70 23
      overlay/tun_linux.go
  44. 5 4
      overlay/tun_netbsd.go
  45. 5 4
      overlay/tun_openbsd.go
  46. 3 2
      overlay/tun_tester.go
  47. 11 6
      overlay/tun_windows.go
  48. 8 3
      overlay/user.go
  49. 4 4
      punchy_test.go
  50. 39 0
      routing/balance.go
  51. 144 0
      routing/balance_test.go
  52. 70 0
      routing/gateway.go
  53. 34 0
      routing/gateway_test.go
  54. 2 2
      service/service_test.go
  55. 46 46
      ssh.go
  56. 5 5
      sshd/command.go
  57. 1 1
      sshd/server.go
  58. 5 5
      sshd/session.go
  59. 1 1
      test/assert.go
  60. 4 2
      test/tun.go
  61. 2 2
      util/error.go
  62. 1 1
      util/error_test.go

+ 1 - 1
.github/workflows/gofmt.yml

@@ -18,7 +18,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.23'
+        go-version: '1.24'
         check-latest: true
 
     - name: Install goimports

+ 4 - 4
.github/workflows/release.yml

@@ -14,7 +14,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.23'
+          go-version: '1.24'
           check-latest: true
 
       - name: Build
@@ -37,7 +37,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.23'
+          go-version: '1.24'
           check-latest: true
 
       - name: Build
@@ -70,12 +70,12 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.23'
+          go-version: '1.24'
           check-latest: true
 
       - name: Import certificates
         if: env.HAS_SIGNING_CREDS == 'true'
-        uses: Apple-Actions/import-codesign-certs@v3
+        uses: Apple-Actions/import-codesign-certs@v5
         with:
           p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
           p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}

+ 1 - 1
.github/workflows/smoke.yml

@@ -22,7 +22,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.23'
+        go-version: '1.24'
         check-latest: true
 
     - name: build

+ 7 - 7
.github/workflows/test.yml

@@ -22,7 +22,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.23'
+        go-version: '1.24'
         check-latest: true
 
     - name: Build
@@ -32,9 +32,9 @@ jobs:
       run: make vet
 
     - name: golangci-lint
-      uses: golangci/golangci-lint-action@v6
+      uses: golangci/golangci-lint-action@v7
       with:
-        version: v1.64
+        version: v2.0
 
     - name: Test
       run: make test
@@ -60,7 +60,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.23'
+        go-version: '1.24'
         check-latest: true
 
     - name: Build
@@ -102,7 +102,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.23'
+        go-version: '1.24'
         check-latest: true
 
     - name: Build nebula
@@ -115,9 +115,9 @@ jobs:
       run: make vet
 
     - name: golangci-lint
-      uses: golangci/golangci-lint-action@v6
+      uses: golangci/golangci-lint-action@v7
       with:
-        version: v1.64
+        version: v2.0
 
     - name: Test
       run: make test

+ 20 - 6
.golangci.yaml

@@ -1,9 +1,23 @@
-# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
+version: "2"
 linters:
-  # Disable all linters.
-  # Default: false
-  disable-all: true
-  # Enable specific linter
-  # https://golangci-lint.run/usage/linters/#enabled-by-default
+  default: none
   enable:
     - testifylint
+  exclusions:
+    generated: lax
+    presets:
+      - comments
+      - common-false-positives
+      - legacy
+      - std-error-handling
+    paths:
+      - third_party$
+      - builtin$
+      - examples$
+formatters:
+  exclusions:
+    generated: lax
+    paths:
+      - third_party$
+      - builtin$
+      - examples$

+ 7 - 0
CHANGELOG.md

@@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Changed
+
+- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
+  intended to target an `unsafe_routes` entry must explicitly declare it via the
+  `local_cidr` field. This is almost always the intended behavior. This flag is
+  deprecated and will be removed in a future release.
+
 ## [1.9.4] - 2024-09-09
 
 ### Added

+ 12 - 26
allow_list.go

@@ -36,7 +36,7 @@ type AllowListNameRule struct {
 
 func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
 	var nameRules []AllowListNameRule
-	handleKey := func(key string, value interface{}) (bool, error) {
+	handleKey := func(key string, value any) (bool, error) {
 		if key == "interfaces" {
 			var err error
 			nameRules, err = getAllowListInterfaces(k, value)
@@ -70,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
 
 // If the handleKey func returns true, the rest of the parsing is skipped
 // for this key. This allows parsing of special values like `interfaces`.
-func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
+func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
 	r := c.Get(k)
 	if r == nil {
 		return nil, nil
@@ -81,8 +81,8 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
 
 // If the handleKey func returns true, the rest of the parsing is skipped
 // for this key. This allows parsing of special values like `interfaces`.
-func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
-	rawMap, ok := raw.(map[interface{}]interface{})
+func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
+	rawMap, ok := raw.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
@@ -100,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 	rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
 	rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
 
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
+	for rawCIDR, rawValue := range rawMap {
 		if handleKey != nil {
 			handled, err := handleKey(rawCIDR, rawValue)
 			if err != nil {
@@ -116,7 +111,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 			}
 		}
 
-		value, ok := rawValue.(bool)
+		value, ok := config.AsBool(rawValue)
 		if !ok {
 			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
 		}
@@ -173,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 	return &AllowList{cidrTree: tree}, nil
 }
 
-func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
+func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
 	var nameRules []AllowListNameRule
 
-	rawRules, ok := v.(map[interface{}]interface{})
+	rawRules, ok := v.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
 	}
 
 	firstEntry := true
 	var allValues bool
-	for rawName, rawAllow := range rawRules {
-		name, ok := rawName.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
-		}
-		allow, ok := rawAllow.(bool)
+	for name, rawAllow := range rawRules {
+		allow, ok := config.AsBool(rawAllow)
 		if !ok {
 			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
 		}
@@ -224,16 +215,11 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
 
 	remoteAllowRanges := new(bart.Table[*AllowList])
 
-	rawMap, ok := value.(map[interface{}]interface{})
+	rawMap, ok := value.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
 	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
+	for rawCIDR, rawValue := range rawMap {
 		allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
 		if err != nil {
 			return nil, err

+ 12 - 12
allow_list_test.go

@@ -15,27 +15,27 @@ import (
 func TestNewAllowListFromConfig(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"192.168.0.0": true,
 	}
 	r, err := newAllowListFromConfig(c, "allowlist", nil)
 	require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
 	assert.Nil(t, r)
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"192.168.0.0/16": "abc",
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
 	require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"192.168.0.0/16": true,
 		"10.0.0.0/8":     false,
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
 	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"0.0.0.0/0":      true,
 		"10.0.0.0/8":     false,
 		"10.42.42.0/24":  true,
@@ -45,7 +45,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
 	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"0.0.0.0/0":     true,
 		"10.0.0.0/8":    false,
 		"10.42.42.0/24": true,
@@ -55,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		assert.NotNil(t, r)
 	}
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"0.0.0.0/0":      true,
 		"10.0.0.0/8":     false,
 		"10.42.42.0/24":  true,
@@ -70,16 +70,16 @@ func TestNewAllowListFromConfig(t *testing.T) {
 
 	// Test interface names
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
+		"interfaces": map[string]any{
 			`docker.*`: "foo",
 		},
 	}
 	lr, err := NewLocalAllowListFromConfig(c, "allowlist")
 	require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
+		"interfaces": map[string]any{
 			`docker.*`: false,
 			`eth.*`:    true,
 		},
@@ -87,8 +87,8 @@ func TestNewAllowListFromConfig(t *testing.T) {
 	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
 	require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
+		"interfaces": map[string]any{
 			`docker.*`: false,
 		},
 	}

+ 1 - 1
cert/cert_v1.go

@@ -41,7 +41,7 @@ type detailsV1 struct {
 	curve Curve
 }
 
-type m map[string]interface{}
+type m = map[string]any
 
 func (c *certificateV1) Version() Version {
 	return Version1

+ 2 - 2
cert/crypto_test.go

@@ -10,14 +10,14 @@ import (
 
 func TestNewArgon2Parameters(t *testing.T) {
 	p := NewArgon2Parameters(64*1024, 4, 3)
-	assert.EqualValues(t, &Argon2Parameters{
+	assert.Equal(t, &Argon2Parameters{
 		version:     argon2.Version,
 		Memory:      64 * 1024,
 		Parallelism: 4,
 		Iterations:  3,
 	}, p)
 	p = NewArgon2Parameters(2*1024*1024, 2, 1)
-	assert.EqualValues(t, &Argon2Parameters{
+	assert.Equal(t, &Argon2Parameters{
 		version:     argon2.Version,
 		Memory:      2 * 1024 * 1024,
 		Parallelism: 2,

+ 20 - 20
cmd/nebula-cert/ca_test.go

@@ -90,26 +90,26 @@ func Test_ca(t *testing.T) {
 	assertHelpError(t, ca(
 		[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
 	), "-name is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// ipv4 only ips
 	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// ipv4 only subnets
 	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
 	require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
@@ -121,8 +121,8 @@ func Test_ca(t *testing.T) {
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
 	require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
@@ -135,8 +135,8 @@ func Test_ca(t *testing.T) {
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	require.NoError(t, ca(args, ob, eb, nopw))
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
@@ -158,7 +158,7 @@ func Test_ca(t *testing.T) {
 	assert.Empty(t, lCrt.UnsafeNetworks())
 	assert.Len(t, lCrt.PublicKey(), 32)
 	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
-	assert.Equal(t, "", lCrt.Issuer())
+	assert.Empty(t, lCrt.Issuer())
 	assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
 
 	// test encrypted key
@@ -169,7 +169,7 @@ func Test_ca(t *testing.T) {
 	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	require.NoError(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// read encrypted key file and verify default params
 	rb, _ = os.ReadFile(keyF.Name())
@@ -197,7 +197,7 @@ func Test_ca(t *testing.T) {
 	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	require.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// test when user fails to enter a password
 	os.Remove(keyF.Name())
@@ -207,7 +207,7 @@ func Test_ca(t *testing.T) {
 	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
@@ -222,8 +222,8 @@ func Test_ca(t *testing.T) {
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// test that we won't overwrite existing key file
 	os.Remove(keyF.Name())
@@ -231,8 +231,8 @@ func Test_ca(t *testing.T) {
 	eb.Reset()
 	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	os.Remove(keyF.Name())
 
 }

+ 10 - 10
cmd/nebula-cert/keygen_test.go

@@ -37,20 +37,20 @@ func Test_keygen(t *testing.T) {
 
 	// required args
 	assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
 	require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
@@ -62,8 +62,8 @@ func Test_keygen(t *testing.T) {
 	eb.Reset()
 	args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
 	require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp pub file
 	pubF, err := os.CreateTemp("", "test.pub")
@@ -75,8 +75,8 @@ func Test_keygen(t *testing.T) {
 	eb.Reset()
 	args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
 	require.NoError(t, keygen(args, ob, eb))
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())

+ 1 - 1
cmd/nebula-cert/main.go

@@ -17,7 +17,7 @@ func (he *helpError) Error() string {
 	return he.s
 }
 
-func newHelpErrorf(s string, v ...interface{}) error {
+func newHelpErrorf(s string, v ...any) error {
 	return &helpError{s: fmt.Sprintf(s, v...)}
 }
 

+ 8 - 8
cmd/nebula-cert/print_test.go

@@ -43,16 +43,16 @@ func Test_printCert(t *testing.T) {
 
 	// no path
 	err := printCert([]string{}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	assertHelpError(t, err, "-path is required")
 
 	// no cert at path
 	ob.Reset()
 	eb.Reset()
 	err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
 
 	// invalid cert at path
@@ -64,8 +64,8 @@ func Test_printCert(t *testing.T) {
 
 	tf.WriteString("-----BEGIN NOPE-----")
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
 
 	// test multiple certs
@@ -155,7 +155,7 @@ func Test_printCert(t *testing.T) {
 `,
 		ob.String(),
 	)
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// test json
 	ob.Reset()
@@ -177,7 +177,7 @@ func Test_printCert(t *testing.T) {
 `,
 		ob.String(),
 	)
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 }
 
 // NewTestCaCert will generate a CA cert

+ 16 - 16
cmd/nebula-cert/verify_test.go

@@ -38,19 +38,19 @@ func Test_verify(t *testing.T) {
 
 	// required args
 	assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// no ca at path
 	ob.Reset()
 	eb.Reset()
 	err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
 
 	// invalid ca at path
@@ -62,8 +62,8 @@ func Test_verify(t *testing.T) {
 
 	caFile.WriteString("-----BEGIN NOPE-----")
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
 
 	// make a ca for later
@@ -76,8 +76,8 @@ func Test_verify(t *testing.T) {
 
 	// no crt at path
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 	// invalid crt at path
@@ -89,8 +89,8 @@ func Test_verify(t *testing.T) {
 
 	certFile.WriteString("-----BEGIN NOPE-----")
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
 
 	// unverifiable cert at path
@@ -106,8 +106,8 @@ func Test_verify(t *testing.T) {
 	certFile.Write(b)
 
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.ErrorIs(t, err, cert.ErrSignatureMismatch)
 
 	// verified cert at path
@@ -118,7 +118,7 @@ func Test_verify(t *testing.T) {
 	certFile.Write(b)
 
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	require.NoError(t, err)
 }

+ 33 - 17
config/config.go

@@ -17,14 +17,14 @@ import (
 
 	"dario.cat/mergo"
 	"github.com/sirupsen/logrus"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 type C struct {
 	path        string
 	files       []string
-	Settings    map[interface{}]interface{}
-	oldSettings map[interface{}]interface{}
+	Settings    map[string]any
+	oldSettings map[string]any
 	callbacks   []func(*C)
 	l           *logrus.Logger
 	reloadLock  sync.Mutex
@@ -32,7 +32,7 @@ type C struct {
 
 func NewC(l *logrus.Logger) *C {
 	return &C{
-		Settings: make(map[interface{}]interface{}),
+		Settings: make(map[string]any),
 		l:        l,
 	}
 }
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
 	}
 
 	var (
-		nv interface{}
-		ov interface{}
+		nv any
+		ov any
 	)
 
 	if k == "" {
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
 	c.reloadLock.Lock()
 	defer c.reloadLock.Unlock()
 
-	c.oldSettings = make(map[interface{}]interface{})
+	c.oldSettings = make(map[string]any)
 	for k, v := range c.Settings {
 		c.oldSettings[k] = v
 	}
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
 	c.reloadLock.Lock()
 	defer c.reloadLock.Unlock()
 
-	c.oldSettings = make(map[interface{}]interface{})
+	c.oldSettings = make(map[string]any)
 	for k, v := range c.Settings {
 		c.oldSettings[k] = v
 	}
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
 		return d
 	}
 
-	rv, ok := r.([]interface{})
+	rv, ok := r.([]any)
 	if !ok {
 		return d
 	}
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
 }
 
 // GetMap will get the map for k or return the default d if not found or invalid
-func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
+func (c *C) GetMap(k string, d map[string]any) map[string]any {
 	r := c.Get(k)
 	if r == nil {
 		return d
 	}
 
-	v, ok := r.(map[interface{}]interface{})
+	v, ok := r.(map[string]any)
 	if !ok {
 		return d
 	}
@@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool {
 	return v
 }
 
+func AsBool(v any) (value bool, ok bool) {
+	switch x := v.(type) {
+	case bool:
+		return x, true
+	case string:
+		switch x {
+		case "y", "yes":
+			return true, true
+		case "n", "no":
+			return false, true
+		}
+	}
+
+	return false, false
+}
+
 // GetDuration will get the duration for k or return the default d if not found or invalid
 func (c *C) GetDuration(k string, d time.Duration) time.Duration {
 	r := c.GetString(k, "")
@@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
 	return v
 }
 
-func (c *C) Get(k string) interface{} {
+func (c *C) Get(k string) any {
 	return c.get(k, c.Settings)
 }
 
@@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool {
 	return c.get(k, c.Settings) != nil
 }
 
-func (c *C) get(k string, v interface{}) interface{} {
+func (c *C) get(k string, v any) any {
 	parts := strings.Split(k, ".")
 	for _, p := range parts {
-		m, ok := v.(map[interface{}]interface{})
+		m, ok := v.(map[string]any)
 		if !ok {
 			return nil
 		}
@@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error {
 }
 
 func (c *C) parseRaw(b []byte) error {
-	var m map[interface{}]interface{}
+	var m map[string]any
 
 	err := yaml.Unmarshal(b, &m)
 	if err != nil {
@@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error {
 }
 
 func (c *C) parse() error {
-	var m map[interface{}]interface{}
+	var m map[string]any
 
 	for _, path := range c.files {
 		b, err := os.ReadFile(path)
@@ -366,7 +382,7 @@ func (c *C) parse() error {
 			return err
 		}
 
-		var nm map[interface{}]interface{}
+		var nm map[string]any
 		err = yaml.Unmarshal(b, &nm)
 		if err != nil {
 			return err

+ 19 - 19
config/config_test.go

@@ -10,7 +10,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 func TestConfig_Load(t *testing.T) {
@@ -19,7 +19,7 @@ func TestConfig_Load(t *testing.T) {
 	// invalid yaml
 	c := NewC(l)
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
-	require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
+	require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
 
 	// simple multi config merge
 	c = NewC(l)
@@ -31,8 +31,8 @@ func TestConfig_Load(t *testing.T) {
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 	os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
 	require.NoError(t, c.Load(dir))
-	expected := map[interface{}]interface{}{
-		"outer": map[interface{}]interface{}{
+	expected := map[string]any{
+		"outer": map[string]any{
 			"inner": "override",
 		},
 		"new": "hi",
@@ -44,12 +44,12 @@ func TestConfig_Get(t *testing.T) {
 	l := test.NewLogger()
 	// test simple type
 	c := NewC(l)
-	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
+	c.Settings["firewall"] = map[string]any{"outbound": "hi"}
 	assert.Equal(t, "hi", c.Get("firewall.outbound"))
 
 	// test complex type
-	inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
-	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
+	inner := []map[string]any{{"port": "1", "code": "2"}}
+	c.Settings["firewall"] = map[string]any{"outbound": inner}
 	assert.EqualValues(t, inner, c.Get("firewall.outbound"))
 
 	// test missing
@@ -59,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
 func TestConfig_GetStringSlice(t *testing.T) {
 	l := test.NewLogger()
 	c := NewC(l)
-	c.Settings["slice"] = []interface{}{"one", "two"}
+	c.Settings["slice"] = []any{"one", "two"}
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 }
 
@@ -101,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
 	// Test key change
 	c = NewC(l)
 	c.Settings["test"] = "hi"
-	c.oldSettings = map[interface{}]interface{}{"test": "no"}
+	c.oldSettings = map[string]any{"test": "no"}
 	assert.True(t, c.HasChanged("test"))
 	assert.True(t, c.HasChanged(""))
 
 	// No key change
 	c = NewC(l)
 	c.Settings["test"] = "hi"
-	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
+	c.oldSettings = map[string]any{"test": "hi"}
 	assert.False(t, c.HasChanged("test"))
 	assert.False(t, c.HasChanged(""))
 }
@@ -184,11 +184,11 @@ firewall:
 `),
 	}
 
-	var m map[any]any
+	var m map[string]any
 
 	// merge the same way config.parse() merges
 	for _, b := range configs {
-		var nm map[any]any
+		var nm map[string]any
 		err := yaml.Unmarshal(b, &nm)
 		require.NoError(t, err)
 
@@ -205,15 +205,15 @@ firewall:
 	t.Logf("Merged Config as YAML:\n%s", mYaml)
 
 	// If a bug is present, some items might be replaced instead of merged like we expect
-	expected := map[any]any{
-		"firewall": map[any]any{
+	expected := map[string]any{
+		"firewall": map[string]any{
 			"inbound": []any{
-				map[any]any{"host": "any", "port": "any", "proto": "icmp"},
-				map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
-				map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
+				map[string]any{"host": "any", "port": "any", "proto": "icmp"},
+				map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
+				map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
 			"outbound": []any{
-				map[any]any{"host": "any", "port": "any", "proto": "any"}}},
-		"listen": map[any]any{
+				map[string]any{"host": "any", "port": "any", "proto": "any"}}},
+		"listen": map[string]any{
 			"host": "0.0.0.0",
 			"port": 4242,
 		},

+ 2 - 2
control_test.go

@@ -101,7 +101,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 
 	// Make sure we don't have any unexpected fields
 	assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
-	assert.EqualValues(t, &expectedInfo, thi)
+	assert.Equal(t, &expectedInfo, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
@@ -110,7 +110,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	})
 }
 
-func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
+func assertFields(t *testing.T, expected []string, actualStruct any) {
 	val := reflect.ValueOf(actualStruct).Elem()
 	fields := make([]string, val.NumField())
 	for i := 0; i < val.NumField(); i++ {

+ 8 - 8
dns_server_test.go

@@ -38,24 +38,24 @@ func TestParsequery(t *testing.T) {
 func Test_getDnsServerAddr(t *testing.T) {
 	c := config.NewC(nil)
 
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "0.0.0.0",
 			"port": "1",
 		},
 	}
 	assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
 
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "::",
 			"port": "1",
 		},
 	}
 	assert.Equal(t, "[::]:1", getDnsServerAddr(c))
 
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "[::]",
 			"port": "1",
 		},
@@ -63,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
 	assert.Equal(t, "[::]:1", getDnsServerAddr(c))
 
 	// Make sure whitespace doesn't mess us up
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "[::] ",
 			"port": "1",
 		},

+ 3 - 3
e2e/handshakes_test.go

@@ -20,7 +20,7 @@ import (
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 func BenchmarkHotPath(b *testing.B) {
@@ -991,7 +991,7 @@ func TestRehandshaking(t *testing.T) {
 	require.NoError(t, err)
 	var theirNewConfig m
 	require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
-	theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
+	theirFirewall := theirNewConfig["firewall"].(map[string]any)
 	theirFirewall["inbound"] = []m{{
 		"proto": "any",
 		"port":  "any",
@@ -1087,7 +1087,7 @@ func TestRehandshakingLoser(t *testing.T) {
 	require.NoError(t, err)
 	var myNewConfig m
 	require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
-	theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
+	theirFirewall := myNewConfig["firewall"].(map[string]any)
 	theirFirewall["inbound"] = []m{{
 		"proto": "any",
 		"port":  "any",

+ 2 - 2
e2e/helpers_test.go

@@ -22,10 +22,10 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/stretchr/testify/assert"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 // newSimpleServer creates a nebula instance with many assumptions
 func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {

+ 30 - 11
examples/config.yml

@@ -239,7 +239,28 @@ tun:
 
   # Unsafe routes allows you to route traffic over nebula to non-nebula nodes
   # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
-  # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
+  # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
+  # NOTES:
+  # * You will only see a single gateway in the routing table if you are not on linux
+  # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
+  #
+  # unsafe_routes:
+  # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
+  # - route: 192.168.87.0/24
+  #   via:
+  #     - gateway: 10.0.0.1
+  #     - gateway: 10.0.0.2
+  #     - gateway: 10.0.0.3
+  # # Multiple gateways with a weight, this will balance traffic accordingly
+  # - route: 192.168.87.0/24
+  #   via:
+  #     - gateway: 10.0.0.1
+  #       weight: 10
+  #     - gateway: 10.0.0.2
+  #       weight: 5
+  #
+  # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
+  # `via`: single node or list of gateways to use for this route
   # `mtu`: will default to tun mtu if this option is not specified
   # `metric`: will default to 0 if this option is not specified
   # `install`: will default to true, controls whether this route is installed in the systems routing table.
@@ -325,11 +346,11 @@ firewall:
   outbound_action: drop
   inbound_action: drop
 
-  # Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
-  # This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
-  # unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
-  # of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
-  # if the intention is to allow traffic to flow to an unsafe route.
+  # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
+  # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
+  # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
+  # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
+  # is explicitly defined. This is usually not the desired behavior and should be avoided!
   #default_local_cidr_any: false
 
   conntrack:
@@ -347,11 +368,9 @@ firewall:
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
   #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
-  #     If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
-  #     Otherwise the default is any vpn network assigned to via the certificate.
-  #     `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
-  #     If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
+  #     By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
+  #     If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
 

+ 5 - 5
firewall.go

@@ -331,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 		return nil
 	}
 
-	rs, ok := r.([]interface{})
+	rs, ok := r.([]any)
 	if !ok {
 		return fmt.Errorf("%s failed to parse, should be an array of rules", table)
 	}
@@ -918,15 +918,15 @@ type rule struct {
 	CASha     string
 }
 
-func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
+func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
 	r := rule{}
 
-	m, ok := p.(map[interface{}]interface{})
+	m, ok := p.(map[string]any)
 	if !ok {
 		return r, errors.New("could not parse rule")
 	}
 
-	toString := func(k string, m map[interface{}]interface{}) string {
+	toString := func(k string, m map[string]any) string {
 		v, ok := m[k]
 		if !ok {
 			return ""
@@ -944,7 +944,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
 	r.CASha = toString("ca_sha", m)
 
 	// Make sure group isn't an array
-	if v, ok := m["group"].([]interface{}); ok {
+	if v, ok := m["group"].([]any); ok {
 		if len(v) > 1 {
 			return r, errors.New("group should contain a single value, an array with more than one entry was provided")
 		}

+ 1 - 1
firewall/packet.go

@@ -6,7 +6,7 @@ import (
 	"net/netip"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 const (
 	ProtoAny    = 0 // When we want to handle HOPOPT (0) we can change this, if ever

+ 27 - 27
firewall_test.go

@@ -631,53 +631,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	require.NoError(t, err)
 
 	conf := config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
+	conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 	// Test both port and code
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 	// Test proto error
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 	// Test cidr parse error
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test local_cidr parse error
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test both group and groups
 	conf = config.NewC(l)
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	_, err = NewFirewallFromConfig(l, cs, conf)
 	require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
@@ -687,28 +687,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	// Test adding tcp rule
 	conf := config.NewC(l)
 	mf := &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding udp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding any rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
@@ -716,49 +716,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	cidr := netip.MustParsePrefix("10.0.0.0/8")
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding rule with local_cidr
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
 
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
 
 	// Test single group
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test single groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
@@ -766,7 +766,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf.nextCallReturn = errors.New("test error")
-	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
 	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
 }
 
@@ -776,8 +776,8 @@ func TestFirewall_convertRule(t *testing.T) {
 	l.SetOutput(ob)
 
 	// Ensure group array of 1 is converted and a warning is printed
-	c := map[interface{}]interface{}{
-		"group": []interface{}{"group1"},
+	c := map[string]any{
+		"group": []any{"group1"},
 	}
 
 	r, err := convertRule(l, c, "test", 1)
@@ -787,17 +787,17 @@ func TestFirewall_convertRule(t *testing.T) {
 
 	// Ensure group array of > 1 is errord
 	ob.Reset()
-	c = map[interface{}]interface{}{
-		"group": []interface{}{"group1", "group2"},
+	c = map[string]any{
+		"group": []any{"group1", "group2"},
 	}
 
 	r, err = convertRule(l, c, "test", 1)
-	assert.Equal(t, "", ob.String())
+	assert.Empty(t, ob.String())
 	require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
 
 	// Make sure a well formed group is alright
 	ob.Reset()
-	c = map[interface{}]interface{}{
+	c = map[string]any{
 		"group": "group1",
 	}
 

+ 8 - 9
go.mod

@@ -2,7 +2,7 @@ module github.com/slackhq/nebula
 
 go 1.23.6
 
-toolchain go1.23.7
+toolchain go1.24.1
 
 require (
 	dario.cat/mergo v1.0.1
@@ -10,11 +10,11 @@ require (
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
-	github.com/gaissmai/bart v0.18.1
+	github.com/gaissmai/bart v0.20.1
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.63
+	github.com/miekg/dns v1.1.64
 	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
 	github.com/prometheus/client_golang v1.21.1
@@ -26,15 +26,15 @@ require (
 	github.com/vishvananda/netlink v1.3.0
 	golang.org/x/crypto v0.36.0
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-	golang.org/x/net v0.37.0
+	golang.org/x/net v0.38.0
 	golang.org/x/sync v0.12.0
 	golang.org/x/sys v0.31.0
 	golang.org/x/term v0.30.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.36.5
-	gopkg.in/yaml.v2 v2.4.0
+	google.golang.org/protobuf v1.36.6
+	gopkg.in/yaml.v3 v3.0.1
 	gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 )
 
@@ -50,8 +50,7 @@ require (
 	github.com/prometheus/common v0.62.0 // indirect
 	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
-	golang.org/x/mod v0.18.0 // indirect
+	golang.org/x/mod v0.23.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
-	golang.org/x/tools v0.22.0 // indirect
-	gopkg.in/yaml.v3 v3.0.1 // indirect
+	golang.org/x/tools v0.30.0 // indirect
 )

+ 12 - 14
go.sum

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
-github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
-github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
+github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo=
+github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY=
-github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs=
+github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ=
+github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -164,8 +164,8 @@ golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPI
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
-golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
+golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -176,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
-golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
+golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
+golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -219,8 +219,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
-golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -239,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
-google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
+google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
+google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -251,8 +251,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
-gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 23 - 2
handshake_ix.go

@@ -71,7 +71,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 
 	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("certVersion", v).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 	}
@@ -185,6 +186,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	var vpnAddrs []netip.Addr
 	var filteredNetworks []netip.Prefix
 	certName := remoteCert.Certificate.Name()
+	certVersion := remoteCert.Certificate.Version()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
 
@@ -194,6 +196,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		if found {
 			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
@@ -212,6 +215,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if len(vpnAddrs) == 0 {
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
@@ -231,6 +235,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if err != nil {
 		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@@ -253,6 +258,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
+		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -264,6 +270,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if hs.Details.Cert == nil {
 		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -281,6 +288,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if err != nil {
 		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@@ -292,6 +300,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if err != nil {
 		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
@@ -299,6 +308,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	} else if dKey == nil || eKey == nil {
 		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@@ -366,6 +376,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
 			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
 				WithField("fingerprint", fingerprint).
@@ -381,6 +392,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
 			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -393,6 +405,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			// And we forget to update it here
 			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -409,6 +422,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		if err != nil {
 			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -417,6 +431,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		} else {
 			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -435,6 +450,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
 		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -539,6 +555,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	vpnNetworks := remoteCert.Certificate.Networks()
 	certName := remoteCert.Certificate.Name()
+	certVersion := remoteCert.Certificate.Version()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
 
@@ -573,6 +590,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	if len(vpnAddrs) == 0 {
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
@@ -582,7 +600,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	// Ensure the right host responded
 	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
 		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
-			WithField("udpAddr", addr).WithField("certName", certName).
+			WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 
@@ -618,6 +638,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	duration := time.Since(hh.startTime).Nanoseconds()
 	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
+		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).

+ 1 - 1
header/header.go

@@ -19,7 +19,7 @@ import (
 // |-----------------------------------------------------------------------|
 // |                               payload...                              |
 
-type m map[string]interface{}
+type m = map[string]any
 
 const (
 	Version uint8 = 1

+ 2 - 2
hostmap_test.go

@@ -210,8 +210,8 @@ func TestHostMap_reload(t *testing.T) {
 	assert.Empty(t, hm.GetPreferredRanges())
 
 	c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
-	assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
+	assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
 
 	c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
-	assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
+	assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
 }

+ 83 - 10
inside.go

@@ -8,6 +8,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
+	"github.com/slackhq/nebula/routing"
 )
 
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
@@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 
+// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
 func (f *Interface) Handshake(vpnAddr netip.Addr) {
-	f.getOrHandshake(vpnAddr, nil)
+	f.getOrHandshakeNoRouting(vpnAddr, nil)
 }
 
-// getOrHandshake returns nil if the vpnAddr is not routable.
+// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
-	if !found {
-		vpnAddr = f.inside.RouteFor(vpnAddr)
-		if !vpnAddr.IsValid() {
-			return nil, false
+	if found {
+		return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
+	}
+
+	return nil, false
+}
+
+// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
+// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
+func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+
+	destinationAddr := fwPacket.RemoteAddr
+
+	hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
+
+	// Host is inside the mesh, no routing required
+	if hostinfo != nil {
+		return hostinfo, ready
+	}
+
+	gateways := f.inside.RoutesFor(destinationAddr)
+
+	switch len(gateways) {
+	case 0:
+		return nil, false
+	case 1:
+		// Single gateway route
+		return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
+	default:
+		// Multi gateway route, perform ECMP categorization
+		gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
+
+		if !balancingOk {
+			// This happens if the gateway buckets were not calculated, this _should_ never happen
+			f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
 		}
+
+		var handshakeInfoForChosenGateway *HandshakeHostInfo
+		var hhReceiver = func(hh *HandshakeHostInfo) {
+			handshakeInfoForChosenGateway = hh
+		}
+
+		// Store the handshakeHostInfo for later.
+		// If this node is not reachable we will attempt other nodes, if none are reachable we will
+		// cache the packet for this gateway.
+		if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
+			return hostinfo, true
+		}
+
+		// It appears the selected gateway cannot be reached, find another gateway to fallback on.
+		// The current implementation breaks ECMP but that seems better than no connectivity.
+		// If ECMP is also required when a gateway is down then connectivity status
+		// for each gateway needs to be kept and the weights recalculated when they go up or down.
+		// This would also need to interact with unsafe_route updates through reloading the config or
+		// use of the use_system_route_table option
+
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("destination", destinationAddr).
+				WithField("originalGateway", gatewayAddr).
+				Debugln("Calculated gateway for ECMP not available, attempting other gateways")
+		}
+
+		for i := range gateways {
+			// Skip the gateway that failed previously
+			if gateways[i].Addr() == gatewayAddr {
+				continue
+			}
+
+			// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
+			if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
+				return hostinfo, true
+			}
+		}
+
+		// No gateways reachable, cache the packet in the originally chosen gateway
+		cacheCallback(handshakeInfoForChosenGateway)
+		return hostinfo, false
 	}
 
-	return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 
 // SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
 func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
+	hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 

+ 3 - 3
lighthouse.go

@@ -422,7 +422,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
 		return err
 	}
 
-	shm := c.GetMap("static_host_map", map[interface{}]interface{}{})
+	shm := c.GetMap("static_host_map", map[string]any{})
 	i := 0
 
 	for k, v := range shm {
@@ -436,9 +436,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
 			return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
 		}
 
-		vals, ok := v.([]interface{})
+		vals, ok := v.([]any)
 		if !ok {
-			vals = []interface{}{v}
+			vals = []any{v}
 		}
 		remoteAddrs := []string{}
 		for _, v := range vals {

+ 15 - 15
lighthouse_test.go

@@ -14,7 +14,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 func TestOldIPv4Only(t *testing.T) {
@@ -40,15 +40,15 @@ func Test_lhStaticMapping(t *testing.T) {
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
-	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
+	c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
+	c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
 	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	require.NoError(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
-	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
+	c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
+	c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
 	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
@@ -65,12 +65,12 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"hosts":    []interface{}{lh1},
+	c.Settings["lighthouse"] = map[string]any{
+		"hosts":    []any{lh1},
 		"interval": "1s",
 	}
 
-	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
+	c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	require.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
@@ -192,8 +192,8 @@ func TestLighthouse_Memory(t *testing.T) {
 	theirVpnIp := netip.MustParseAddr("10.128.0.3")
 
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
-	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
+	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
+	c.Settings["listen"] = map[string]any{"port": 4242}
 
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
 	nt := new(bart.Table[struct{}])
@@ -277,8 +277,8 @@ func TestLighthouse_Memory(t *testing.T) {
 func TestLighthouse_reload(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
-	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
+	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
+	c.Settings["listen"] = map[string]any{"port": 4242}
 
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
 	nt := new(bart.Table[struct{}])
@@ -291,9 +291,9 @@ func TestLighthouse_reload(t *testing.T) {
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	require.NoError(t, err)
 
-	nc := map[interface{}]interface{}{
-		"static_host_map": map[interface{}]interface{}{
-			"10.128.0.2": []interface{}{"1.1.1.1:4242"},
+	nc := map[string]any{
+		"static_host_map": map[string]any{
+			"10.128.0.2": []any{"1.1.1.1:4242"},
 		},
 	}
 	rc, err := yaml.Marshal(nc)

+ 2 - 2
main.go

@@ -13,10 +13,10 @@ import (
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())

+ 3 - 1
overlay/device.go

@@ -3,6 +3,8 @@ package overlay
 import (
 	"io"
 	"net/netip"
+
+	"github.com/slackhq/nebula/routing"
 )
 
 type Device interface {
@@ -10,6 +12,6 @@ type Device interface {
 	Activate() error
 	Networks() []netip.Prefix
 	Name() string
-	RouteFor(netip.Addr) netip.Addr
+	RoutesFor(netip.Addr) routing.Gateways
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }

+ 69 - 17
overlay/route.go

@@ -11,13 +11,14 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 )
 
 type Route struct {
 	MTU     int
 	Metric  int
 	Cidr    netip.Prefix
-	Via     netip.Addr
+	Via     routing.Gateways
 	Install bool
 }
 
@@ -47,15 +48,17 @@ func (r Route) String() string {
 	return s
 }
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
-	routeTree := new(bart.Table[netip.Addr])
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
+	routeTree := new(bart.Table[routing.Gateways])
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
 		}
 
-		if r.Via.IsValid() {
-			routeTree.Insert(r.Cidr, r.Via)
+		gateways := r.Via
+		if len(gateways) > 0 {
+			routing.CalculateBucketsForGateways(gateways)
+			routeTree.Insert(r.Cidr, gateways)
 		}
 	}
 	return routeTree, nil
@@ -69,7 +72,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 		return []Route{}, nil
 	}
 
-	rawRoutes, ok := r.([]interface{})
+	rawRoutes, ok := r.([]any)
 	if !ok {
 		return nil, fmt.Errorf("tun.routes is not an array")
 	}
@@ -80,7 +83,7 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 
 	routes := make([]Route, len(rawRoutes))
 	for i, r := range rawRoutes {
-		m, ok := r.(map[interface{}]interface{})
+		m, ok := r.(map[string]any)
 		if !ok {
 			return nil, fmt.Errorf("entry %v in tun.routes is invalid", i+1)
 		}
@@ -148,7 +151,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 		return []Route{}, nil
 	}
 
-	rawRoutes, ok := r.([]interface{})
+	rawRoutes, ok := r.([]any)
 	if !ok {
 		return nil, fmt.Errorf("tun.unsafe_routes is not an array")
 	}
@@ -159,7 +162,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 
 	routes := make([]Route, len(rawRoutes))
 	for i, r := range rawRoutes {
-		m, ok := r.(map[interface{}]interface{})
+		m, ok := r.(map[string]any)
 		if !ok {
 			return nil, fmt.Errorf("entry %v in tun.unsafe_routes is invalid", i+1)
 		}
@@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
 		}
 
-		via, ok := rVia.(string)
-		if !ok {
-			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
-		}
+		var gateways routing.Gateways
 
-		viaVpnIp, err := netip.ParseAddr(via)
-		if err != nil {
-			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
+		switch via := rVia.(type) {
+		case string:
+			viaIp, err := netip.ParseAddr(via)
+			if err != nil {
+				return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
+			}
+
+			gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
+
+		case []any:
+			gateways = make(routing.Gateways, len(via))
+			for ig, v := range via {
+				gatewayMap, ok := v.(map[string]any)
+				if !ok {
+					return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
+				}
+
+				rGateway, ok := gatewayMap["gateway"]
+				if !ok {
+					return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
+				}
+
+				parsedGateway, ok := rGateway.(string)
+				if !ok {
+					return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
+				}
+
+				gatewayIp, err := netip.ParseAddr(parsedGateway)
+				if err != nil {
+					return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
+				}
+
+				rGatewayWeight, ok := gatewayMap["weight"]
+				if !ok {
+					rGatewayWeight = 1
+				}
+
+				gatewayWeight, ok := rGatewayWeight.(int)
+				if !ok {
+					_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
+					if err != nil {
+						return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
+					}
+				}
+
+				if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
+					return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
+				}
+
+				gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
+
+			}
+
+		default:
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
 		}
 
 		rRoute, ok := m["route"]
@@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 		}
 
 		r := Route{
-			Via:     viaVpnIp,
+			Via:     gateways,
 			MTU:     mtu,
 			Metric:  metric,
 			Install: install,

+ 147 - 41
overlay/route_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -23,75 +24,75 @@ func Test_parseRoutes(t *testing.T) {
 	assert.Empty(t, routes)
 
 	// not an array
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
+	c.Settings["tun"] = map[string]any{"routes": "hi"}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "tun.routes is not an array")
 
 	// no routes
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
+	c.Settings["tun"] = map[string]any{"routes": []any{}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	require.NoError(t, err)
 	assert.Empty(t, routes)
 
 	// weird route
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
+	c.Settings["tun"] = map[string]any{"routes": []any{"asdf"}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 	// no mtu
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 	// bad mtu
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "nope"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "499"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 	// missing route
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 	// unparsable route
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "nope"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// below network range
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "1.0.0.0/8"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
 
 	// above network range
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "10.0.1.0/24"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
 
 	// Not in multiple ranges
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
+	c.Settings["tun"] = map[string]any{"routes": []any{map[string]any{"mtu": "500", "route": "192.0.0.0/24"}}}
 	routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
 
 	// happy case
-	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
-		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
-		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
+	c.Settings["tun"] = map[string]any{"routes": []any{
+		map[string]any{"mtu": "9000", "route": "10.0.0.0/29"},
+		map[string]any{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
 	routes, err = parseRoutes(c, []netip.Prefix{n})
 	require.NoError(t, err)
@@ -128,105 +129,129 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	assert.Empty(t, routes)
 
 	// not an array
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": "hi"}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 	// no routes
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	require.NoError(t, err)
 	assert.Empty(t, routes)
 
 	// weird route
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{"asdf"}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 	// no via
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
 	// invalid via
-	for _, invalidValue := range []interface{}{
+	for _, invalidValue := range []any{
 		127, false, nil, 1.0, []string{"1", "2"},
 	} {
-		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
+		c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": invalidValue}}}
 		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
-		require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
+		require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
 	}
 
+	// Unparsable list of via
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": []string{"1", "2"}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
+
 	// unparsable via
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
+	// unparsable gateway
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "1"}}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
+
+	// missing gateway element
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"weight": "1"}}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
+
+	// unparsable weight element
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"mtu": "500", "via": []any{map[string]any{"gateway": "10.0.0.1", "weight": "a"}}}}}
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
+	assert.Nil(t, routes)
+	require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
+
 	// missing route
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 	// unparsable route
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// within network range
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 	// below network range
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	require.NoError(t, err)
 
 	// above network range
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	require.NoError(t, err)
 
 	// no mtu
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Equal(t, 0, routes[0].MTU)
 
 	// bad mtu
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "499"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 	// bad install
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 
 	// happy case
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
-		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
-		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
-		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
-		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
+		map[string]any{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "t"},
+		map[string]any{"via": "127.0.0.1", "mtu": "8000", "route": "1.0.0.1/32", "install": 0},
+		map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
+		map[string]any{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
 	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	require.NoError(t, err)
@@ -263,9 +288,9 @@ func Test_makeRouteTree(t *testing.T) {
 	n, err := netip.ParsePrefix("10.0.0.0/24")
 	require.NoError(t, err)
 
-	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
-		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
-		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
+	c.Settings["tun"] = map[string]any{"unsafe_routes": []any{
+		map[string]any{"via": "192.168.0.1", "route": "1.0.0.0/28"},
+		map[string]any{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 	}}
 	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	require.NoError(t, err)
@@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) {
 
 	nip, err := netip.ParseAddr("192.168.0.1")
 	require.NoError(t, err)
-	assert.Equal(t, nip, r)
+	assert.Equal(t, nip, r[0].Addr())
 
 	ip, err = netip.ParseAddr("1.0.0.1")
 	require.NoError(t, err)
@@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) {
 
 	nip, err = netip.ParseAddr("192.168.0.2")
 	require.NoError(t, err)
-	assert.Equal(t, nip, r)
+	assert.Equal(t, nip, r[0].Addr())
 
 	ip, err = netip.ParseAddr("1.1.0.1")
 	require.NoError(t, err)
 	r, ok = routeTree.Lookup(ip)
 	assert.False(t, ok)
 }
+
+func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
+	l := test.NewLogger()
+	c := config.NewC(l)
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	require.NoError(t, err)
+
+	c.Settings["tun"] = map[string]any{
+		"unsafe_routes": []any{
+			map[string]any{
+				"route": "192.168.86.0/24",
+				"via":   "192.168.100.10",
+			},
+			map[string]any{
+				"route": "192.168.87.0/24",
+				"via": []any{
+					map[string]any{
+						"gateway": "10.0.0.1",
+					},
+					map[string]any{
+						"gateway": "10.0.0.2",
+					},
+					map[string]any{
+						"gateway": "10.0.0.3",
+					},
+				},
+			},
+			map[string]any{
+				"route": "192.168.89.0/24",
+				"via": []any{
+					map[string]any{
+						"gateway": "10.0.0.1",
+						"weight":  10,
+					},
+					map[string]any{
+						"gateway": "10.0.0.2",
+						"weight":  5,
+					},
+				},
+			},
+		},
+	}
+
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
+	require.NoError(t, err)
+	assert.Len(t, routes, 3)
+	routeTree, err := makeRouteTree(l, routes, true)
+	require.NoError(t, err)
+
+	ip, err := netip.ParseAddr("192.168.86.1")
+	require.NoError(t, err)
+	r, ok := routeTree.Lookup(ip)
+	assert.True(t, ok)
+
+	nip, err := netip.ParseAddr("192.168.100.10")
+	require.NoError(t, err)
+	assert.Equal(t, nip, r[0].Addr())
+
+	ip, err = netip.ParseAddr("192.168.87.1")
+	require.NoError(t, err)
+	r, ok = routeTree.Lookup(ip)
+	assert.True(t, ok)
+
+	expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
+		routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
+		routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
+
+	routing.CalculateBucketsForGateways(expectedGateways)
+	assert.ElementsMatch(t, expectedGateways, r)
+
+	ip, err = netip.ParseAddr("192.168.89.1")
+	require.NoError(t, err)
+	r, ok = routeTree.Lookup(ip)
+	assert.True(t, ok)
+
+	expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
+		routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
+
+	routing.CalculateBucketsForGateways(expectedGateways)
+	assert.ElementsMatch(t, expectedGateways, r)
+}

+ 3 - 2
overlay/tun_android.go

@@ -13,6 +13,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -21,7 +22,7 @@ type tun struct {
 	fd          int
 	vpnNetworks []netip.Prefix
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 }
 
@@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }

+ 6 - 5
overlay/tun_darwin.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	netroute "golang.org/x/net/route"
 	"golang.org/x/sys/unix"
@@ -28,7 +29,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	DefaultMTU  int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	linkAddr    *netroute.LinkAddr
 	l           *logrus.Logger
 
@@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, ok := t.routeTree.Load().Lookup(ip)
 	if ok {
 		return r
 	}
-	return netip.Addr{}
+	return routing.Gateways{}
 }
 
 // Get the LinkAddr for the interface of the given name
@@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
@@ -393,7 +394,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 				t.l.WithField("route", r.Cidr).
 					Warnf("unable to add unsafe_route, identical route already exists")
 			} else {
-				retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+				retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 				if logErrors {
 					retErr.Log(t.l)
 				} else {

+ 3 - 2
overlay/tun_disabled.go

@@ -9,6 +9,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/routing"
 )
 
 type disabledTun struct {
@@ -43,8 +44,8 @@ func (*disabledTun) Activate() error {
 	return nil
 }
 
-func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
-	return netip.Addr{}
+func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
+	return routing.Gateways{}
 }
 
 func (t *disabledTun) Networks() []netip.Prefix {

+ 5 - 4
overlay/tun_freebsd.go

@@ -20,6 +20,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -50,7 +51,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	io.ReadWriteCloser
@@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
@@ -270,7 +271,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
+			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {

+ 3 - 2
overlay/tun_ios.go

@@ -16,6 +16,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -23,7 +24,7 @@ type tun struct {
 	io.ReadWriteCloser
 	vpnNetworks []netip.Prefix
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 }
 
@@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }

+ 70 - 23
overlay/tun_linux.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
@@ -34,7 +35,7 @@ type tun struct {
 	ioctlFd     uintptr
 
 	Routes          atomic.Pointer[[]Route]
-	routeTree       atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree       atomic.Pointer[bart.Table[routing.Gateways]]
 	routeChan       chan struct{}
 	useSystemRoutes bool
 
@@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return file, nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -463,7 +464,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 
 		err := netlink.RouteReplace(&nr)
 		if err != nil {
-			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {
@@ -550,20 +551,7 @@ func (t *tun) watchRoutes() {
 	}()
 }
 
-func (t *tun) updateRoutes(r netlink.RouteUpdate) {
-	if r.Gw == nil {
-		// Not a gateway route, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
-		return
-	}
-
-	gwAddr, ok := netip.AddrFromSlice(r.Gw)
-	if !ok {
-		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
-		return
-	}
-
-	gwAddr = gwAddr.Unmap()
+func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
 	withinNetworks := false
 	for i := range t.vpnNetworks {
 		if t.vpnNetworks[i].Contains(gwAddr) {
@@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 			break
 		}
 	}
-	if !withinNetworks {
-		// Gateway isn't in our overlay network, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
+
+	return withinNetworks
+}
+
+func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
+
+	var gateways routing.Gateways
+
+	link, err := netlink.LinkByName(t.Device)
+	if err != nil {
+		t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
+		return gateways
+	}
+
+	// If this route is relevant to our interface and there is a gateway then add it
+	if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
+		gwAddr, ok := netip.AddrFromSlice(r.Gw)
+		if !ok {
+			t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
+		} else {
+			gwAddr = gwAddr.Unmap()
+
+			if !t.isGatewayInVpnNetworks(gwAddr) {
+				// Gateway isn't in our overlay network, ignore
+				t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
+			} else {
+				gateways = append(gateways, routing.NewGateway(gwAddr, 1))
+			}
+		}
+	}
+
+	for _, p := range r.MultiPath {
+		// If this route is relevant to our interface and there is a gateway then add it
+		if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
+			gwAddr, ok := netip.AddrFromSlice(p.Gw)
+			if !ok {
+				t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
+			} else {
+				gwAddr = gwAddr.Unmap()
+
+				if !t.isGatewayInVpnNetworks(gwAddr) {
+					// Gateway isn't in our overlay network, ignore
+					t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
+				} else {
+					// p.Hops+1 = weight of the route
+					gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
+				}
+			}
+		}
+	}
+
+	routing.CalculateBucketsForGateways(gateways)
+	return gateways
+}
+
+func (t *tun) updateRoutes(r netlink.RouteUpdate) {
+
+	gateways := t.getGatewaysFromRoute(&r.Route)
+
+	if len(gateways) == 0 {
+		// No gateways relevant to our network, no routing changes required.
+		t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
 		return
 	}
 
@@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	newTree := t.routeTree.Load().Clone()
 
 	if r.Type == unix.RTM_NEWROUTE {
-		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
-		newTree.Insert(dst, gwAddr)
+		t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
+		newTree.Insert(dst, gateways)
 
 	} else {
+		t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
 		newTree.Delete(dst)
-		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 	}
 	t.routeTree.Store(newTree)
 }

+ 5 - 4
overlay/tun_netbsd.go

@@ -18,6 +18,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -31,7 +32,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	io.ReadWriteCloser
@@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
@@ -205,7 +206,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
+			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {

+ 5 - 4
overlay/tun_openbsd.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -25,7 +26,7 @@ type tun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	io.ReadWriteCloser
@@ -158,7 +159,7 @@ func (t *tun) Activate() error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
@@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
@@ -174,7 +175,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
+			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {

+ 3 - 2
overlay/tun_tester.go

@@ -13,13 +13,14 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 )
 
 type TestTun struct {
 	Device      string
 	vpnNetworks []netip.Prefix
 	Routes      []Route
-	routeTree   *bart.Table[netip.Addr]
+	routeTree   *bart.Table[routing.Gateways]
 	l           *logrus.Logger
 
 	closed    atomic.Bool
@@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Lookup(ip)
 	return r
 }

+ 11 - 6
overlay/tun_windows.go

@@ -18,6 +18,7 @@ import (
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/wintun"
 	"golang.org/x/sys/windows"
@@ -31,7 +32,7 @@ type winTun struct {
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
-	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
 
 	tun *wintun.NativeTun
@@ -147,15 +148,18 @@ func (t *winTun) addRoutes(logErrors bool) error {
 	foundDefault4 := false
 
 	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
+		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
 		// Add our unsafe route
-		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
+		// Windows does not support multipath routes natively, so we install only a single route.
+		// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
+		// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
+		err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
 		if err != nil {
-			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 				continue
@@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		err := luid.DeleteRoute(r.Cidr, r.Via)
+		// See comment on luid.AddRoute
+		err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
@@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }

+ 8 - 3
overlay/user.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/routing"
 )
 
 func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
@@ -38,9 +39,13 @@ type UserDevice struct {
 func (d *UserDevice) Activate() error {
 	return nil
 }
-func (d *UserDevice) Networks() []netip.Prefix          { return d.vpnNetworks }
-func (d *UserDevice) Name() string                      { return "faketun0" }
-func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
+
+func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
+func (d *UserDevice) Name() string             { return "faketun0" }
+func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
+	return routing.Gateways{routing.NewGateway(ip, 1)}
+}
+
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return d, nil
 }

+ 4 - 4
punchy_test.go

@@ -27,7 +27,7 @@ func TestNewPunchyFromConfig(t *testing.T) {
 	assert.True(t, p.GetPunch())
 
 	// punchy.punch
-	c.Settings["punchy"] = map[interface{}]interface{}{"punch": true}
+	c.Settings["punchy"] = map[string]any{"punch": true}
 	p = NewPunchyFromConfig(l, c)
 	assert.True(t, p.GetPunch())
 
@@ -37,18 +37,18 @@ func TestNewPunchyFromConfig(t *testing.T) {
 	assert.True(t, p.GetRespond())
 
 	// punchy.respond
-	c.Settings["punchy"] = map[interface{}]interface{}{"respond": true}
+	c.Settings["punchy"] = map[string]any{"respond": true}
 	c.Settings["punch_back"] = false
 	p = NewPunchyFromConfig(l, c)
 	assert.True(t, p.GetRespond())
 
 	// punchy.delay
-	c.Settings["punchy"] = map[interface{}]interface{}{"delay": "1m"}
+	c.Settings["punchy"] = map[string]any{"delay": "1m"}
 	p = NewPunchyFromConfig(l, c)
 	assert.Equal(t, time.Minute, p.GetDelay())
 
 	// punchy.respond_delay
-	c.Settings["punchy"] = map[interface{}]interface{}{"respond_delay": "1m"}
+	c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
 	p = NewPunchyFromConfig(l, c)
 	assert.Equal(t, time.Minute, p.GetRespondDelay())
 }

+ 39 - 0
routing/balance.go

@@ -0,0 +1,39 @@
+package routing
+
+import (
+	"net/netip"
+
+	"github.com/slackhq/nebula/firewall"
+)
+
+// Hashes the packet source and destination port and always returns a positive integer
+// Based on 'Prospecting for Hash Functions'
+//   - https://nullprogram.com/blog/2018/07/31/
+//   - https://github.com/skeeto/hash-prospector
+//     [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
+func hashPacket(p *firewall.Packet) int {
+	x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
+	x ^= x >> 16
+	x *= 0x21f0aaad
+	x ^= x >> 15
+	x *= 0xd35a2d97
+	x ^= x >> 15
+
+	return int(x) & 0x7FFFFFFF
+}
+
+// For this function to work correctly it requires that the buckets for the gateways have been calculated
+// If the contract is violated balancing will not work properly and the second return value will return false
+func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
+	hash := hashPacket(fwPacket)
+
+	for i := range gateways {
+		if hash <= gateways[i].BucketUpperBound() {
+			return gateways[i].Addr(), true
+		}
+	}
+
+	// If you land here then the buckets for the gateways are not properly calculated
+	// Fallback to random routing and let the caller know
+	return gateways[hash%len(gateways)].Addr(), false
+}

+ 144 - 0
routing/balance_test.go

@@ -0,0 +1,144 @@
+package routing
+
+import (
+	"net/netip"
+	"testing"
+
+	"github.com/slackhq/nebula/firewall"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestPacketsAreBalancedEqually(t *testing.T) {
+
+	gateways := []Gateway{}
+
+	gw1Addr := netip.MustParseAddr("1.0.0.1")
+	gw2Addr := netip.MustParseAddr("1.0.0.2")
+	gw3Addr := netip.MustParseAddr("1.0.0.3")
+
+	gateways = append(gateways, NewGateway(gw1Addr, 1))
+	gateways = append(gateways, NewGateway(gw2Addr, 1))
+	gateways = append(gateways, NewGateway(gw3Addr, 1))
+
+	CalculateBucketsForGateways(gateways)
+
+	gw1count := 0
+	gw2count := 0
+	gw3count := 0
+
+	iterationCount := uint16(65535)
+	for i := uint16(0); i < iterationCount; i++ {
+		packet := firewall.Packet{
+			LocalAddr:  netip.MustParseAddr("192.168.1.1"),
+			RemoteAddr: netip.MustParseAddr("10.0.0.1"),
+			LocalPort:  i,
+			RemotePort: 65535 - i,
+			Protocol:   6, // TCP
+			Fragment:   false,
+		}
+
+		selectedGw, ok := BalancePacket(&packet, gateways)
+		assert.True(t, ok)
+
+		switch selectedGw {
+		case gw1Addr:
+			gw1count += 1
+		case gw2Addr:
+			gw2count += 1
+		case gw3Addr:
+			gw3count += 1
+		}
+
+	}
+
+	// Assert packets are balanced, allow variation of up to 100 packets per gateway
+	assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
+	assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
+	assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
+
+}
+
+func TestPacketsAreBalancedByPriority(t *testing.T) {
+
+	gateways := []Gateway{}
+
+	gw1Addr := netip.MustParseAddr("1.0.0.1")
+	gw2Addr := netip.MustParseAddr("1.0.0.2")
+
+	gateways = append(gateways, NewGateway(gw1Addr, 10))
+	gateways = append(gateways, NewGateway(gw2Addr, 5))
+
+	CalculateBucketsForGateways(gateways)
+
+	gw1count := 0
+	gw2count := 0
+
+	iterationCount := uint16(65535)
+	for i := uint16(0); i < iterationCount; i++ {
+		packet := firewall.Packet{
+			LocalAddr:  netip.MustParseAddr("192.168.1.1"),
+			RemoteAddr: netip.MustParseAddr("10.0.0.1"),
+			LocalPort:  i,
+			RemotePort: 65535 - i,
+			Protocol:   6, // TCP
+			Fragment:   false,
+		}
+
+		selectedGw, ok := BalancePacket(&packet, gateways)
+		assert.True(t, ok)
+
+		switch selectedGw {
+		case gw1Addr:
+			gw1count += 1
+		case gw2Addr:
+			gw2count += 1
+		}
+
+	}
+
+	iterationCountAsFloat := float32(iterationCount)
+
+	assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
+	assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
+}
+
+func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
+	gateways := []Gateway{}
+
+	gw1Addr := netip.MustParseAddr("1.0.0.1")
+	gw2Addr := netip.MustParseAddr("1.0.0.2")
+
+	gateways = append(gateways, NewGateway(gw1Addr, 10))
+	gateways = append(gateways, NewGateway(gw2Addr, 5))
+
+	iterationCount := uint16(65535)
+	gw1count := 0
+	gw2count := 0
+
+	for i := uint16(0); i < iterationCount; i++ {
+		packet := firewall.Packet{
+			LocalAddr:  netip.MustParseAddr("192.168.1.1"),
+			RemoteAddr: netip.MustParseAddr("10.0.0.1"),
+			LocalPort:  i,
+			RemotePort: 65535 - i,
+			Protocol:   6, // TCP
+			Fragment:   false,
+		}
+
+		selectedGw, ok := BalancePacket(&packet, gateways)
+		assert.False(t, ok)
+
+		switch selectedGw {
+		case gw1Addr:
+			gw1count += 1
+		case gw2Addr:
+			gw2count += 1
+		}
+
+	}
+
+	assert.Equal(t, int(iterationCount), (gw1count + gw2count))
+	assert.NotEqual(t, 0, gw1count)
+	assert.NotEqual(t, 0, gw2count)
+
+}

+ 70 - 0
routing/gateway.go

@@ -0,0 +1,70 @@
+package routing
+
+import (
+	"fmt"
+	"net/netip"
+)
+
+const (
+	// Sentinal value
+	BucketNotCalculated = -1
+)
+
+type Gateways []Gateway
+
+func (g Gateways) String() string {
+	str := ""
+	for i, gw := range g {
+		str += gw.String()
+		if i < len(g)-1 {
+			str += ", "
+		}
+	}
+	return str
+}
+
+type Gateway struct {
+	addr             netip.Addr
+	weight           int
+	bucketUpperBound int
+}
+
+func NewGateway(addr netip.Addr, weight int) Gateway {
+	return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated}
+}
+
+func (g *Gateway) BucketUpperBound() int {
+	return g.bucketUpperBound
+}
+
+func (g *Gateway) Addr() netip.Addr {
+	return g.addr
+}
+
+func (g *Gateway) String() string {
+	return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
+}
+
+// Divide and round to nearest integer
+func divideAndRound(v uint64, d uint64) uint64 {
+	var tmp uint64 = v + d/2
+	return tmp / d
+}
+
+// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
+// After this function returns each gateway will have a
+// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
+func CalculateBucketsForGateways(gateways []Gateway) {
+
+	var totalWeight int = 0
+	for i := range gateways {
+		totalWeight += gateways[i].weight
+	}
+
+	var loopWeight int = 0
+	for i := range gateways {
+		loopWeight += gateways[i].weight
+		gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
+	}
+
+}

+ 34 - 0
routing/gateway_test.go

@@ -0,0 +1,34 @@
+package routing
+
+import (
+	"net/netip"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRebalance3_2Split(t *testing.T) {
+	gateways := []Gateway{}
+
+	gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10})
+	gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5})
+
+	CalculateBucketsForGateways(gateways)
+
+	assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
+	assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
+}
+
+func TestRebalanceEqualSplit(t *testing.T) {
+	gateways := []Gateway{}
+
+	gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
+	gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
+	gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
+
+	CalculateBucketsForGateways(gateways)
+
+	assert.Equal(t, 715827882, gateways[0].bucketUpperBound)  // INT_MAX/3
+	assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
+	assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
+}

+ 2 - 2
service/service_test.go

@@ -13,10 +13,10 @@ import (
 	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"golang.org/x/sync/errgroup"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
 	_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})

+ 46 - 46
ssh.go

@@ -124,10 +124,10 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 	}
 
 	rawKeys := c.Get("sshd.authorized_users")
-	keys, ok := rawKeys.([]interface{})
+	keys, ok := rawKeys.([]any)
 	if ok {
 		for _, rk := range keys {
-			kDef, ok := rk.(map[interface{}]interface{})
+			kDef, ok := rk.(map[string]any)
 			if !ok {
 				l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
 				continue
@@ -148,7 +148,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 					continue
 				}
 
-			case []interface{}:
+			case []any:
 				for _, subK := range v {
 					sk, ok := subK.(string)
 					if !ok {
@@ -190,7 +190,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "list-hostmap",
 		ShortDescription: "List all known previously connected hosts",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
@@ -198,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshListHostMap(f.hostMap, fs, w)
 		},
 	})
@@ -206,7 +206,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "list-pending-hostmap",
 		ShortDescription: "List all handshaking hosts",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
@@ -214,7 +214,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshListHostMap(f.handshakeManager, fs, w)
 		},
 	})
@@ -222,14 +222,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "list-lighthouse-addrmap",
 		ShortDescription: "List all lighthouse map entries",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshListLighthouseMap(f.lightHouse, fs, w)
 		},
 	})
@@ -237,7 +237,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "reload",
 		ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshReload(c, w)
 		},
 	})
@@ -251,7 +251,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "stop-cpu-profile",
 		ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			pprof.StopCPUProfile()
 			return w.WriteLine("If a CPU profile was running it is now stopped")
 		},
@@ -278,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "log-level",
 		ShortDescription: "Gets or sets the current log level",
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshLogLevel(l, fs, a, w)
 		},
 	})
@@ -286,7 +286,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "log-format",
 		ShortDescription: "Gets or sets the current log format",
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshLogFormat(l, fs, a, w)
 		},
 	})
@@ -294,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "version",
 		ShortDescription: "Prints the currently running version of nebula",
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshVersion(f, fs, a, w)
 		},
 	})
@@ -302,14 +302,14 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "device-info",
 		ShortDescription: "Prints information about the network device.",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshDeviceInfoFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshDeviceInfo(f, fs, w)
 		},
 	})
@@ -317,7 +317,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-cert",
 		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintCertFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json")
@@ -325,7 +325,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 			fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshPrintCert(f, fs, a, w)
 		},
 	})
@@ -333,13 +333,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-tunnel",
 		ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintTunnelFlags{}
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshPrintTunnel(f, fs, a, w)
 		},
 	})
@@ -347,13 +347,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-relays",
 		ShortDescription: "Prints json details about all relay info",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintTunnelFlags{}
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshPrintRelays(f, fs, a, w)
 		},
 	})
@@ -361,13 +361,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "change-remote",
 		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshChangeRemoteFlags{}
 			fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshChangeRemote(f, fs, a, w)
 		},
 	})
@@ -375,13 +375,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "close-tunnel",
 		ShortDescription: "Closes a tunnel for the provided vpn addr",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshCloseTunnelFlags{}
 			fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshCloseTunnel(f, fs, a, w)
 		},
 	})
@@ -390,13 +390,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 		Name:             "create-tunnel",
 		ShortDescription: "Creates a tunnel for the provided vpn address",
 		Help:             "The lighthouses will be queried for real addresses but you can provide one as well.",
-		Flags: func() (*flag.FlagSet, interface{}) {
+		Flags: func() (*flag.FlagSet, any) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshCreateTunnelFlags{}
 			fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
 			return fl, &s
 		},
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshCreateTunnel(f, fs, a, w)
 		},
 	})
@@ -405,13 +405,13 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 		Name:             "query-lighthouse",
 		ShortDescription: "Query the lighthouses for the provided vpn address",
 		Help:             "This command is asynchronous. Only currently known udp addresses will be printed.",
-		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
+		Callback: func(fs any, a []string, w sshd.StringWriter) error {
 			return sshQueryLighthouse(f, fs, a, w)
 		},
 	})
 }
 
-func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
+func sshListHostMap(hl controlHostLister, a any, w sshd.StringWriter) error {
 	fs, ok := a.(*sshListHostMapFlags)
 	if !ok {
 		return nil
@@ -451,7 +451,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 	return nil
 }
 
-func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
+func sshListLighthouseMap(lightHouse *LightHouse, a any, w sshd.StringWriter) error {
 	fs, ok := a.(*sshListHostMapFlags)
 	if !ok {
 		return nil
@@ -505,7 +505,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	return nil
 }
 
-func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
+func sshStartCpuProfile(fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		err := w.WriteLine("No path to write profile provided")
 		return err
@@ -527,11 +527,11 @@ func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
 	return err
 }
 
-func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshVersion(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	return w.WriteLine(fmt.Sprintf("%s", ifce.version))
 }
 
-func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshQueryLighthouse(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine("No vpn address was provided")
 	}
@@ -553,7 +553,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 	return json.NewEncoder(w.GetWriter()).Encode(cm)
 }
 
-func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshCloseTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	flags, ok := fs.(*sshCloseTunnelFlags)
 	if !ok {
 		return nil
@@ -593,7 +593,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	return w.WriteLine("Closed")
 }
 
-func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	flags, ok := fs.(*sshCreateTunnelFlags)
 	if !ok {
 		return nil
@@ -638,7 +638,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 	return w.WriteLine("Created")
 }
 
-func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshChangeRemote(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	flags, ok := fs.(*sshChangeRemoteFlags)
 	if !ok {
 		return nil
@@ -675,7 +675,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 	return w.WriteLine("Changed")
 }
 
-func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
+func sshGetHeapProfile(fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine("No path to write profile provided")
 	}
@@ -696,7 +696,7 @@ func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
 	return err
 }
 
-func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error {
+func sshMutexProfileFraction(fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		rate := runtime.SetMutexProfileFraction(-1)
 		return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
@@ -711,7 +711,7 @@ func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) er
 	return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
 }
 
-func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
+func sshGetMutexProfile(fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine("No path to write profile provided")
 	}
@@ -735,7 +735,7 @@ func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
 	return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
 }
 
-func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshLogLevel(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
 	}
@@ -749,7 +749,7 @@ func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWrit
 	return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
 }
 
-func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshLogFormat(l *logrus.Logger, fs any, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
 		return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
 	}
@@ -767,7 +767,7 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri
 	return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
 }
 
-func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshPrintCert(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	args, ok := fs.(*sshPrintCertFlags)
 	if !ok {
 		return nil
@@ -822,7 +822,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 	return w.WriteLine(cert.String())
 }
 
-func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshPrintRelays(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	args, ok := fs.(*sshPrintTunnelFlags)
 	if !ok {
 		w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
@@ -919,7 +919,7 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	return nil
 }
 
-func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
+func sshPrintTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) error {
 	args, ok := fs.(*sshPrintTunnelFlags)
 	if !ok {
 		return nil
@@ -951,7 +951,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
 }
 
-func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
+func sshDeviceInfo(ifce *Interface, fs any, w sshd.StringWriter) error {
 
 	data := struct {
 		Name string         `json:"name"`

+ 5 - 5
sshd/command.go

@@ -12,7 +12,7 @@ import (
 
 // CommandFlags is a function called before help or command execution to parse command line flags
 // It should return a flag.FlagSet instance and a pointer to the struct that will contain parsed flags
-type CommandFlags func() (*flag.FlagSet, interface{})
+type CommandFlags func() (*flag.FlagSet, any)
 
 // CommandCallback is the function called when your command should execute.
 // fs will be a a pointer to the struct provided by Command.Flags callback, if there was one. -h and -help are reserved
@@ -21,7 +21,7 @@ type CommandFlags func() (*flag.FlagSet, interface{})
 // w is the writer to use when sending messages back to the client.
 // If an error is returned by the callback it is logged locally, the callback should handle messaging errors to the user
 // where appropriate
-type CommandCallback func(fs interface{}, a []string, w StringWriter) error
+type CommandCallback func(fs any, a []string, w StringWriter) error
 
 type Command struct {
 	Name             string
@@ -34,7 +34,7 @@ type Command struct {
 func execCommand(c *Command, args []string, w StringWriter) error {
 	var (
 		fl *flag.FlagSet
-		fs interface{}
+		fs any
 	)
 
 	if c.Flags != nil {
@@ -85,7 +85,7 @@ func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
 
 func matchCommand(c *radix.Tree, cmd string) []string {
 	cmds := make([]string, 0)
-	c.WalkPrefix(cmd, func(found string, v interface{}) bool {
+	c.WalkPrefix(cmd, func(found string, v any) bool {
 		cmds = append(cmds, found)
 		return false
 	})
@@ -95,7 +95,7 @@ func matchCommand(c *radix.Tree, cmd string) []string {
 
 func allCommands(c *radix.Tree) []*Command {
 	cmds := make([]*Command, 0)
-	c.WalkPrefix("", func(found string, v interface{}) bool {
+	c.WalkPrefix("", func(found string, v any) bool {
 		cmd, ok := v.(*Command)
 		if ok {
 			cmds = append(cmds, cmd)

+ 1 - 1
sshd/server.go

@@ -86,7 +86,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
 	s.RegisterCommand(&Command{
 		Name:             "help",
 		ShortDescription: "prints available commands or help <command> for specific usage info",
-		Callback: func(a interface{}, args []string, w StringWriter) error {
+		Callback: func(a any, args []string, w StringWriter) error {
 			return helpCallback(s.commands, args, w)
 		},
 	})

+ 5 - 5
sshd/session.go

@@ -9,13 +9,13 @@ import (
 	"github.com/armon/go-radix"
 	"github.com/sirupsen/logrus"
 	"golang.org/x/crypto/ssh"
-	"golang.org/x/crypto/ssh/terminal"
+	"golang.org/x/term"
 )
 
 type session struct {
 	l        *logrus.Entry
 	c        *ssh.ServerConn
-	term     *terminal.Terminal
+	term     *term.Terminal
 	commands *radix.Tree
 	exitChan chan bool
 }
@@ -31,7 +31,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
 	s.commands.Insert("logout", &Command{
 		Name:             "logout",
 		ShortDescription: "Ends the current session",
-		Callback: func(a interface{}, args []string, w StringWriter) error {
+		Callback: func(a any, args []string, w StringWriter) error {
 			s.Close()
 			return nil
 		},
@@ -106,8 +106,8 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
 	}
 }
 
-func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
-	term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
+func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
+	term := term.NewTerminal(channel, s.c.User()+"@nebula > ")
 	term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
 		// key 9 is tab
 		if key == 9 {

+ 1 - 1
test/assert.go

@@ -13,7 +13,7 @@ import (
 
 // AssertDeepCopyEqual checks to see if two variables have the same values but DO NOT share any memory
 // There is currently a special case for `time.loc` (as this code traverses into unexported fields)
-func AssertDeepCopyEqual(t *testing.T, a interface{}, b interface{}) {
+func AssertDeepCopyEqual(t *testing.T, a any, b any) {
 	v1 := reflect.ValueOf(a)
 	v2 := reflect.ValueOf(b)
 

+ 4 - 2
test/tun.go

@@ -4,12 +4,14 @@ import (
 	"errors"
 	"io"
 	"net/netip"
+
+	"github.com/slackhq/nebula/routing"
 )
 
 type NoopTun struct{}
 
-func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
-	return netip.Addr{}
+func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
+	return routing.Gateways{}
 }
 
 func (NoopTun) Activate() error {

+ 2 - 2
util/error.go

@@ -9,11 +9,11 @@ import (
 
 type ContextualError struct {
 	RealError error
-	Fields    map[string]interface{}
+	Fields    map[string]any
 	Context   string
 }
 
-func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
+func NewContextualError(msg string, fields map[string]any, realError error) *ContextualError {
 	return &ContextualError{Context: msg, Fields: fields, RealError: realError}
 }
 

+ 1 - 1
util/error_test.go

@@ -9,7 +9,7 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 type TestLogWriter struct {
 	Logs []string