Преглед на файлове

Merge tag 'v1.8.2' into multiport

1.8.2 Release
Wade Simmons преди 1 година
родител
ревизия
659d7fece6
променени са 71 файла, в които са добавени 1783 реда и са изтрити 857 реда
  1. 1 1
      .github/workflows/gofmt.yml
  2. 3 3
      .github/workflows/release.yml
  3. 2 2
      .github/workflows/smoke.yml
  4. 6 3
      .github/workflows/test.yml
  5. 74 1
      CHANGELOG.md
  6. 8 1
      Makefile
  7. 12 1
      README.md
  8. 14 29
      allow_list.go
  9. 1 1
      allow_list_test.go
  10. 2 2
      calculated_remote.go
  11. 34 28
      cidr/tree4.go
  12. 71 47
      cidr/tree4_test.go
  13. 24 20
      cidr/tree6.go
  14. 45 28
      cidr/tree6_test.go
  15. 7 8
      cmd/nebula-cert/ca.go
  16. 5 6
      cmd/nebula-cert/ca_test.go
  17. 2 3
      cmd/nebula-cert/keygen.go
  18. 4 5
      cmd/nebula-cert/keygen_test.go
  19. 2 3
      cmd/nebula-cert/print.go
  20. 1 2
      cmd/nebula-cert/print_test.go
  21. 6 7
      cmd/nebula-cert/sign.go
  22. 11 12
      cmd/nebula-cert/sign_test.go
  23. 2 3
      cmd/nebula-cert/verify.go
  24. 2 3
      cmd/nebula-cert/verify_test.go
  25. 5 2
      config/config.go
  26. 7 8
      config/config_test.go
  27. 10 5
      connection_manager.go
  28. 12 11
      connection_manager_test.go
  29. 0 3
      connection_state.go
  30. 10 2
      control.go
  31. 1 2
      control_test.go
  32. 18 18
      e2e/handshakes_test.go
  33. 118 0
      e2e/helpers.go
  34. 1 110
      e2e/helpers_test.go
  35. 5 1
      examples/config.yml
  36. 100 0
      examples/go_service/main.go
  37. 36 14
      firewall.go
  38. 8 4
      firewall_test.go
  39. 14 11
      go.mod
  40. 31 24
      go.sum
  41. 0 31
      handshake.go
  42. 47 58
      handshake_ix.go
  43. 170 94
      handshake_manager.go
  44. 10 1
      handshake_manager_test.go
  45. 25 82
      hostmap.go
  46. 26 12
      inside.go
  47. 17 7
      interface.go
  48. 34 7
      iputil/packet.go
  49. 73 0
      iputil/packet_test.go
  50. 56 28
      lighthouse.go
  51. 23 4
      main.go
  52. 4 5
      outside.go
  53. 2 2
      overlay/route.go
  54. 8 10
      overlay/route_test.go
  55. 24 8
      overlay/tun.go
  56. 3 7
      overlay/tun_android.go
  57. 4 4
      overlay/tun_darwin.go
  58. 3 7
      overlay/tun_freebsd.go
  59. 3 7
      overlay/tun_ios.go
  60. 5 9
      overlay/tun_linux.go
  61. 3 7
      overlay/tun_netbsd.go
  62. 3 7
      overlay/tun_openbsd.go
  63. 3 7
      overlay/tun_tester.go
  64. 3 7
      overlay/tun_water_windows.go
  65. 3 7
      overlay/tun_wintun_windows.go
  66. 63 0
      overlay/user.go
  67. 36 0
      service/listener.go
  68. 248 0
      service/service.go
  69. 165 0
      service/service_test.go
  70. 2 3
      ssh.go
  71. 2 2
      test/logger.go

+ 1 - 1
.github/workflows/gofmt.yml

@@ -16,7 +16,7 @@ jobs:
 
     - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true

+ 3 - 3
.github/workflows/release.yml

@@ -12,7 +12,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
 
-      - uses: actions/setup-go@v4
+      - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
           check-latest: true
@@ -35,7 +35,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
 
-      - uses: actions/setup-go@v4
+      - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
           check-latest: true
@@ -68,7 +68,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
 
-      - uses: actions/setup-go@v4
+      - uses: actions/setup-go@v5
         with:
           go-version-file: 'go.mod'
           check-latest: true

+ 2 - 2
.github/workflows/smoke.yml

@@ -20,13 +20,13 @@ jobs:
 
     - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true
 
     - name: build
-      run: make bin-docker
+      run: make bin-docker CGO_ENABLED=1 BUILD_ARGS=-race
 
     - name: setup docker image
       working-directory: ./.github/workflows/smoke

+ 6 - 3
.github/workflows/test.yml

@@ -20,7 +20,7 @@ jobs:
 
     - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true
@@ -37,6 +37,9 @@ jobs:
     - name: End 2 end
       run: make e2evv
 
+    - name: Build test mobile
+      run: make build-test-mobile
+
     - uses: actions/upload-artifact@v3
       with:
         name: e2e packet flow
@@ -50,7 +53,7 @@ jobs:
 
     - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true
@@ -74,7 +77,7 @@ jobs:
 
     - uses: actions/checkout@v4
 
-    - uses: actions/setup-go@v4
+    - uses: actions/setup-go@v5
       with:
         go-version-file: 'go.mod'
         check-latest: true

+ 74 - 1
CHANGELOG.md

@@ -7,6 +7,76 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.8.2] - 2024-01-08
+
+### Fixed
+
+- Fix multiple routines when listen.port is zero. This was a regression
+  introduced in v1.6.0. (#1057)
+
+### Changed
+
+- Small dependency update for Noise. (#1038)
+
+## [1.8.1] - 2023-12-19
+
+### Security
+
+- Update `golang.org/x/crypto`, which includes a fix for CVE-2023-48795. (#1048)
+
+### Fixed
+
+- Fix a deadlock introduced in v1.8.0 that could occur during handshakes. (#1044)
+
+- Fix mobile builds. (#1035)
+
+## [1.8.0] - 2023-12-06
+
+### Deprecated
+
+- The next minor release of Nebula, 1.9.0, will require at least Windows 10 or
+  Windows Server 2016. This is because support for earlier versions was removed
+  in Go 1.21. See https://go.dev/doc/go1.21#windows
+
+### Added
+
+- Linux: Notify systemd of service readiness. This should resolve timing issues
+  with services that depend on Nebula being active. For an example of how to
+  enable this, see: `examples/service_scripts/nebula.service`. (#929)
+
+- Windows: Use Registered IO (RIO) when possible. Testing on a Windows 11
+  machine shows ~50x improvement in throughput. (#905)
+
+- NetBSD, OpenBSD: Added rudimentary support. (#916, #812)
+
+- FreeBSD: Add support for naming tun devices. (#903)
+
+### Changed
+
+- `pki.disconnect_invalid` will now default to true. This means that once a
+  certificate expires, the tunnel will be disconnected. If you use SIGHUP to
+  reload certificates without restarting Nebula, you should ensure all of your
+  clients are on 1.7.0 or newer before you enable this feature. (#859)
+
+- Limit how often a busy tunnel can requery the lighthouse. The new config
+  option `timers.requery_wait_duration` defaults to `60s`. (#940)
+
+- The internal structures for hostmaps were refactored to reduce memory usage
+  and the potential for subtle bugs. (#843, #938, #953, #954, #955)
+
+- Lots of dependency updates.
+
+### Fixed
+
+- Windows: Retry wintun device creation if it fails the first time. (#985)
+
+- Fix issues with firewall reject packets that could cause panics. (#957)
+
+- Fix relay migration during re-handshakes. (#964)
+
+- Various other refactors and fixes. (#935, #952, #972, #961, #996, #1002,
+  #987, #1004, #1030, #1032, ...)
+
 ## [1.7.2] - 2023-06-01
 
 ### Fixed
@@ -488,7 +558,10 @@ created.)
 
 - Initial public release.
 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.7.2...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.8.2...HEAD
+[1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2
+[1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1
+[1.8.0]: https://github.com/slackhq/nebula/releases/tag/v1.8.0
 [1.7.2]: https://github.com/slackhq/nebula/releases/tag/v1.7.2
 [1.7.1]: https://github.com/slackhq/nebula/releases/tag/v1.7.1
 [1.7.0]: https://github.com/slackhq/nebula/releases/tag/v1.7.0

+ 8 - 1
Makefile

@@ -169,6 +169,12 @@ test-cov-html:
 	go test -coverprofile=coverage.out
 	go tool cover -html=coverage.out
 
+build-test-mobile:
+	GOARCH=amd64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
+	GOARCH=arm64 GOOS=ios go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
+	GOARCH=amd64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
+	GOARCH=arm64 GOOS=android go build $(shell go list ./... | grep -v '/cmd/\|/examples/')
+
 bench:
 	go test -bench=.
 
@@ -214,8 +220,9 @@ smoke-multiport-docker: bin-docker
 	cd .github/workflows/smoke/ && NAME="smoke-multiport" ./smoke.sh
 
 smoke-docker-race: BUILD_ARGS = -race
+smoke-docker-race: CGO_ENABLED = 1
 smoke-docker-race: smoke-docker
 
 .FORCE:
-.PHONY: e2e e2ev e2evv e2evvv e2evvvv test test-cov-html bench bench-cpu bench-cpu-long bin proto release service smoke-docker smoke-docker-race
+.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html
 .DEFAULT_GOAL := bin

+ 12 - 1
README.md

@@ -27,15 +27,26 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
 
 #### Distribution Packages
 
-- [Arch Linux](https://archlinux.org/packages/community/x86_64/nebula/)
+- [Arch Linux](https://archlinux.org/packages/extra/x86_64/nebula/)
     ```
     $ sudo pacman -S nebula
     ```
+
 - [Fedora Linux](https://src.fedoraproject.org/rpms/nebula)
     ```
     $ sudo dnf install nebula
     ```
 
+- [Debian Linux](https://packages.debian.org/source/stable/nebula)
+    ```
+    $ sudo apt install nebula
+    ```
+
+- [Alpine Linux](https://pkgs.alpinelinux.org/packages?name=nebula)
+    ```
+    $ sudo apk add nebula
+    ```
+
 - [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb)
     ```
     $ brew install nebula

+ 14 - 29
allow_list.go

@@ -12,7 +12,7 @@ import (
 
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *cidr.Tree6
+	cidrTree *cidr.Tree6[bool]
 }
 
 type RemoteAllowList struct {
@@ -20,7 +20,7 @@ type RemoteAllowList struct {
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
-	insideAllowLists *cidr.Tree6
+	insideAllowLists *cidr.Tree6[*AllowList]
 }
 
 type LocalAllowList struct {
@@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
 
-	tree := cidr.NewTree6()
+	tree := cidr.NewTree6[bool]()
 
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	type allowListRules struct {
@@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
 	return nameRules, nil
 }
 
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
+func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	remoteAllowRanges := cidr.NewTree6()
+	remoteAllowRanges := cidr.NewTree6[*AllowList]()
 
 	rawMap, ok := value.(map[interface{}]interface{})
 	if !ok {
@@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContains(ip)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContains(ip)
+	return result
 }
 
 func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
@@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContainsIpV4(ip)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
+	return result
 }
 
 func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
@@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+	return result
 }
 
 func (al *LocalAllowList) Allow(ip net.IP) bool {
@@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
 
 func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
 	if al.insideAllowLists != nil {
-		inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
-		if inside != nil {
-			return inside.(*AllowList)
+		ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+		if ok {
+			return inside
 		}
 	}
 	return nil

+ 1 - 1
allow_list_test.go

@@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 func TestAllowList_Allow(t *testing.T) {
 	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
 
-	tree := cidr.NewTree6()
+	tree := cidr.NewTree6[bool]()
 	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
 	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
 	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)

+ 2 - 2
calculated_remote.go

@@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
 	return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
 }
 
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	calculatedRemotes := cidr.NewTree4()
+	calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
 
 	rawMap, ok := value.(map[any]any)
 	if !ok {

+ 34 - 28
cidr/tree4.go

@@ -6,35 +6,36 @@ import (
 	"github.com/slackhq/nebula/iputil"
 )
 
-type Node struct {
-	left   *Node
-	right  *Node
-	parent *Node
-	value  interface{}
+type Node[T any] struct {
+	left     *Node[T]
+	right    *Node[T]
+	parent   *Node[T]
+	hasValue bool
+	value    T
 }
 
-type entry struct {
+type entry[T any] struct {
 	CIDR  *net.IPNet
-	Value *interface{}
+	Value T
 }
 
-type Tree4 struct {
-	root *Node
-	list []entry
+type Tree4[T any] struct {
+	root *Node[T]
+	list []entry[T]
 }
 
 const (
 	startbit = iputil.VpnIp(0x80000000)
 )
 
-func NewTree4() *Tree4 {
-	tree := new(Tree4)
-	tree.root = &Node{}
-	tree.list = []entry{}
+func NewTree4[T any]() *Tree4[T] {
+	tree := new(Tree4[T])
+	tree.root = &Node[T]{}
+	tree.list = []entry[T]{}
 	return tree
 }
 
-func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
+func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
 	bit := startbit
 	node := tree.root
 	next := tree.root
@@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 			}
 		}
 
-		tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
+		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
 		node.value = val
+		node.hasValue = true
 		return
 	}
 
 	// Build up the rest of the tree we don't already have
 	for bit&mask != 0 {
-		next = &Node{}
+		next = &Node[T]{}
 		next.parent = node
 
 		if ip&bit != 0 {
@@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
-	tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
+	node.hasValue = true
+	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
 }
 
 // Contains finds the first match, which may be the least specific
-func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 
 	for node != nil {
-		if node.value != nil {
-			return node.value
+		if node.hasValue {
+			return true, node.value
 		}
 
 		if ip&bit != 0 {
@@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 
 	}
 
-	return value
+	return false, value
 }
 
 // MostSpecificContains finds the most specific match
-func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 
 	for node != nil {
-		if node.value != nil {
+		if node.hasValue {
 			value = node.value
+			ok = true
 		}
 
 		if ip&bit != 0 {
@@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 		bit >>= 1
 	}
 
-	return value
+	return ok, value
 }
 
 // Match finds the most specific match
-func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
+// TODO this is exact match
+func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 	lastNode := node
@@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 
 	if bit == 0 && lastNode != nil {
 		value = lastNode.value
+		ok = true
 	}
-	return value
+	return ok, value
 }
 
 // List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4) List() []entry {
+func (tree *Tree4[T]) List() []entry[T] {
 	return tree.list
 }

+ 71 - 47
cidr/tree4_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestCIDRTree_List(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
 	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
@@ -17,13 +17,13 @@ func TestCIDRTree_List(t *testing.T) {
 	list := tree.List()
 	assert.Len(t, list, 2)
 	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
-	assert.Equal(t, "2", *list[0].Value)
+	assert.Equal(t, "2", list[0].Value)
 	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
-	assert.Equal(t, "4", *list[1].Value)
+	assert.Equal(t, "4", list[1].Value)
 }
 
 func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -33,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) {
 	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4a", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4a", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -71,59 +79,75 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
 	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4b", "4.1.1.2"},
+		{true, "4c", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func TestCIDRTree_Match(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
 	tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1a", "4.1.1.0"},
-		{"1b", "4.1.1.1"},
+		{true, "1a", "4.1.1.0"},
+		{true, "1b", "4.1.1.1"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
 	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
@@ -145,7 +169,7 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
 }
 
 func BenchmarkCIDRTree_Match(b *testing.B) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
 	tree.AddCIDR(Parse("192.2.1.1/32"), "1")

+ 24 - 20
cidr/tree6.go

@@ -8,20 +8,20 @@ import (
 
 const startbit6 = uint64(1 << 63)
 
-type Tree6 struct {
-	root4 *Node
-	root6 *Node
+type Tree6[T any] struct {
+	root4 *Node[T]
+	root6 *Node[T]
 }
 
-func NewTree6() *Tree6 {
-	tree := new(Tree6)
-	tree.root4 = &Node{}
-	tree.root6 = &Node{}
+func NewTree6[T any]() *Tree6[T] {
+	tree := new(Tree6[T])
+	tree.root4 = &Node[T]{}
+	tree.root6 = &Node[T]{}
 	return tree
 }
 
-func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
-	var node, next *Node
+func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
+	var node, next *Node[T]
 
 	cidrIP, ipv4 := isIPV4(cidr.IP)
 	if ipv4 {
@@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 		// Build up the rest of the tree we don't already have
 		for bit&mask != 0 {
-			next = &Node{}
+			next = &Node[T]{}
 			next.parent = node
 
 			if ip&bit != 0 {
@@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
+	node.hasValue = true
 }
 
 // Finds the most specific match
-func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
-	var node *Node
+func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
+	var node *Node[T]
 
 	wholeIP, ipv4 := isIPV4(ip)
 	if ipv4 {
@@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
 		bit := startbit
 
 		for node != nil {
-			if node.value != nil {
+			if node.hasValue {
 				value = node.value
+				ok = true
 			}
 
 			if bit == 0 {
@@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
 		}
 	}
 
-	return value
+	return ok, value
 }
 
-func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root4
 
 	for node != nil {
-		if node.value != nil {
+		if node.hasValue {
 			value = node.value
+			ok = true
 		}
 
 		if ip&bit != 0 {
@@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{})
 		bit >>= 1
 	}
 
-	return value
+	return ok, value
 }
 
-func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
+func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
 	ip := hi
 	node := tree.root6
 
@@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 		bit := startbit6
 
 		for node != nil {
-			if node.value != nil {
+			if node.hasValue {
 				value = node.value
+				ok = true
 			}
 
 			if bit == 0 {
@@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 		ip = lo
 	}
 
-	return value
+	return ok, value
 }
 
 func isIPV4(ip net.IP) (net.IP, bool) {

+ 45 - 28
cidr/tree6_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewTree6()
+	tree := NewTree6[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{"6a", "1:2:0:4:1:1:1:1"},
-		{"6b", "1:2:0:4:5:1:1:1"},
-		{"6c", "1:2:0:4:5:0:0:0"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4b", "4.1.1.2"},
+		{true, "4c", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{true, "6a", "1:2:0:4:1:1:1:1"},
+		{true, "6b", "1:2:0:4:5:1:1:1"},
+		{true, "6c", "1:2:0:4:5:0:0:0"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
+		ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree6()
+	tree = NewTree6[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
 	tree.AddCIDR(Parse("::/0"), "cool6")
-	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
-	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
-	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
-	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")))
+	ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("::"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool6", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool6", r)
 }
 
 func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewTree6()
+	tree := NewTree6[string]()
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"6a", "1:2:0:4:1:1:1:1"},
-		{"6b", "1:2:0:4:5:1:1:1"},
-		{"6c", "1:2:0:4:5:0:0:0"},
+		{true, "6a", "1:2:0:4:1:1:1:1"},
+		{true, "6b", "1:2:0:4:5:1:1:1"},
+		{true, "6c", "1:2:0:4:5:0:0:0"},
 	}
 
 	for _, tt := range tests {
@@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
 		hi := binary.BigEndian.Uint64(ip[:8])
 		lo := binary.BigEndian.Uint64(ip[8:])
 
-		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
+		ok, r := tree.MostSpecificContainsIpV6(hi, lo)
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 }

+ 7 - 8
cmd/nebula-cert/ca.go

@@ -7,7 +7,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"math"
 	"net"
 	"os"
@@ -213,27 +212,27 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return fmt.Errorf("error while signing: %s", err)
 	}
 
+	var b []byte
 	if *cf.encryption {
-		b, err := cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+		b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
 		if err != nil {
 			return fmt.Errorf("error while encrypting out-key: %s", err)
 		}
-
-		err = ioutil.WriteFile(*cf.outKeyPath, b, 0600)
 	} else {
-		err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalSigningPrivateKey(curve, rawPriv), 0600)
+		b = cert.MarshalSigningPrivateKey(curve, rawPriv)
 	}
 
+	err = os.WriteFile(*cf.outKeyPath, b, 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-key: %s", err)
 	}
 
-	b, err := nc.MarshalToPEM()
+	b, err = nc.MarshalToPEM()
 	if err != nil {
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}
 
-	err = ioutil.WriteFile(*cf.outCertPath, b, 0600)
+	err = os.WriteFile(*cf.outCertPath, b, 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-crt: %s", err)
 	}
@@ -244,7 +243,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 			return fmt.Errorf("error while generating qr code: %s", err)
 		}
 
-		err = ioutil.WriteFile(*cf.outQRPath, b, 0600)
+		err = os.WriteFile(*cf.outQRPath, b, 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-qr: %s", err)
 		}

+ 5 - 6
cmd/nebula-cert/ca_test.go

@@ -7,7 +7,6 @@ import (
 	"bytes"
 	"encoding/pem"
 	"errors"
-	"io/ioutil"
 	"os"
 	"strings"
 	"testing"
@@ -107,7 +106,7 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp key file
-	keyF, err := ioutil.TempFile("", "test.key")
+	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	os.Remove(keyF.Name())
 
@@ -120,7 +119,7 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp cert file
-	crtF, err := ioutil.TempFile("", "test.crt")
+	crtF, err := os.CreateTemp("", "test.crt")
 	assert.Nil(t, err)
 	os.Remove(crtF.Name())
 	os.Remove(keyF.Name())
@@ -134,13 +133,13 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// read cert and key files
-	rb, _ := ioutil.ReadFile(keyF.Name())
+	rb, _ := os.ReadFile(keyF.Name())
 	lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 64)
 
-	rb, _ = ioutil.ReadFile(crtF.Name())
+	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
@@ -166,7 +165,7 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// read encrypted key file and verify default params
-	rb, _ = ioutil.ReadFile(keyF.Name())
+	rb, _ = os.ReadFile(keyF.Name())
 	k, _ := pem.Decode(rb)
 	ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
 	assert.Nil(t, err)

+ 2 - 3
cmd/nebula-cert/keygen.go

@@ -4,7 +4,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 
 	"github.com/slackhq/nebula/cert"
@@ -54,12 +53,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("invalid curve: %s", *cf.curve)
 	}
 
-	err = ioutil.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+	err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-key: %s", err)
 	}
 
-	err = ioutil.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
+	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
 	}

+ 4 - 5
cmd/nebula-cert/keygen_test.go

@@ -2,7 +2,6 @@ package main
 
 import (
 	"bytes"
-	"io/ioutil"
 	"os"
 	"testing"
 
@@ -54,7 +53,7 @@ func Test_keygen(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp key file
-	keyF, err := ioutil.TempFile("", "test.key")
+	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	defer os.Remove(keyF.Name())
 
@@ -67,7 +66,7 @@ func Test_keygen(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// create temp pub file
-	pubF, err := ioutil.TempFile("", "test.pub")
+	pubF, err := os.CreateTemp("", "test.pub")
 	assert.Nil(t, err)
 	defer os.Remove(pubF.Name())
 
@@ -80,13 +79,13 @@ func Test_keygen(t *testing.T) {
 	assert.Equal(t, "", eb.String())
 
 	// read cert and key files
-	rb, _ := ioutil.ReadFile(keyF.Name())
+	rb, _ := os.ReadFile(keyF.Name())
 	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 
-	rb, _ = ioutil.ReadFile(pubF.Name())
+	rb, _ = os.ReadFile(pubF.Name())
 	lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)

+ 2 - 3
cmd/nebula-cert/print.go

@@ -5,7 +5,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 	"strings"
 
@@ -41,7 +40,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	rawCert, err := ioutil.ReadFile(*pf.path)
+	rawCert, err := os.ReadFile(*pf.path)
 	if err != nil {
 		return fmt.Errorf("unable to read cert; %s", err)
 	}
@@ -87,7 +86,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 			return fmt.Errorf("error while generating qr code: %s", err)
 		}
 
-		err = ioutil.WriteFile(*pf.outQRPath, b, 0600)
+		err = os.WriteFile(*pf.outQRPath, b, 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-qr: %s", err)
 		}

+ 1 - 2
cmd/nebula-cert/print_test.go

@@ -2,7 +2,6 @@ package main
 
 import (
 	"bytes"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -54,7 +53,7 @@ func Test_printCert(t *testing.T) {
 	// invalid cert at path
 	ob.Reset()
 	eb.Reset()
-	tf, err := ioutil.TempFile("", "print-cert")
+	tf, err := os.CreateTemp("", "print-cert")
 	assert.Nil(t, err)
 	defer os.Remove(tf.Name())
 

+ 6 - 7
cmd/nebula-cert/sign.go

@@ -6,7 +6,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net"
 	"os"
 	"strings"
@@ -73,7 +72,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 
-	rawCAKey, err := ioutil.ReadFile(*sf.caKeyPath)
+	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
 	if err != nil {
 		return fmt.Errorf("error while reading ca-key: %s", err)
 	}
@@ -112,7 +111,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 
-	rawCACert, err := ioutil.ReadFile(*sf.caCertPath)
+	rawCACert, err := os.ReadFile(*sf.caCertPath)
 	if err != nil {
 		return fmt.Errorf("error while reading ca-crt: %s", err)
 	}
@@ -178,7 +177,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 	var pub, rawPriv []byte
 	if *sf.inPubPath != "" {
-		rawPub, err := ioutil.ReadFile(*sf.inPubPath)
+		rawPub, err := os.ReadFile(*sf.inPubPath)
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
@@ -235,7 +234,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}
 
-		err = ioutil.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-key: %s", err)
 		}
@@ -246,7 +245,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}
 
-	err = ioutil.WriteFile(*sf.outCertPath, b, 0600)
+	err = os.WriteFile(*sf.outCertPath, b, 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-crt: %s", err)
 	}
@@ -257,7 +256,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 			return fmt.Errorf("error while generating qr code: %s", err)
 		}
 
-		err = ioutil.WriteFile(*sf.outQRPath, b, 0600)
+		err = os.WriteFile(*sf.outQRPath, b, 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-qr: %s", err)
 		}

+ 11 - 12
cmd/nebula-cert/sign_test.go

@@ -7,7 +7,6 @@ import (
 	"bytes"
 	"crypto/rand"
 	"errors"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -104,7 +103,7 @@ func Test_signCert(t *testing.T) {
 	// failed to unmarshal key
 	ob.Reset()
 	eb.Reset()
-	caKeyF, err := ioutil.TempFile("", "sign-cert.key")
+	caKeyF, err := os.CreateTemp("", "sign-cert.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 
@@ -128,7 +127,7 @@ func Test_signCert(t *testing.T) {
 	// failed to unmarshal cert
 	ob.Reset()
 	eb.Reset()
-	caCrtF, err := ioutil.TempFile("", "sign-cert.crt")
+	caCrtF, err := os.CreateTemp("", "sign-cert.crt")
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 
@@ -159,7 +158,7 @@ func Test_signCert(t *testing.T) {
 	// failed to unmarshal pub
 	ob.Reset()
 	eb.Reset()
-	inPubF, err := ioutil.TempFile("", "in.pub")
+	inPubF, err := os.CreateTemp("", "in.pub")
 	assert.Nil(t, err)
 	defer os.Remove(inPubF.Name())
 
@@ -206,7 +205,7 @@ func Test_signCert(t *testing.T) {
 
 	// mismatched ca key
 	_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
-	caKeyF2, err := ioutil.TempFile("", "sign-cert-2.key")
+	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF2.Name())
 	caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
@@ -227,7 +226,7 @@ func Test_signCert(t *testing.T) {
 	assert.Empty(t, eb.String())
 
 	// create temp key file
-	keyF, err := ioutil.TempFile("", "test.key")
+	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	os.Remove(keyF.Name())
 
@@ -241,7 +240,7 @@ func Test_signCert(t *testing.T) {
 	os.Remove(keyF.Name())
 
 	// create temp cert file
-	crtF, err := ioutil.TempFile("", "test.crt")
+	crtF, err := os.CreateTemp("", "test.crt")
 	assert.Nil(t, err)
 	os.Remove(crtF.Name())
 
@@ -254,13 +253,13 @@ func Test_signCert(t *testing.T) {
 	assert.Empty(t, eb.String())
 
 	// read cert and key files
-	rb, _ := ioutil.ReadFile(keyF.Name())
+	rb, _ := os.ReadFile(keyF.Name())
 	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 
-	rb, _ = ioutil.ReadFile(crtF.Name())
+	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
@@ -296,7 +295,7 @@ func Test_signCert(t *testing.T) {
 	assert.Empty(t, eb.String())
 
 	// read cert file and check pub key matches in-pub
-	rb, _ = ioutil.ReadFile(crtF.Name())
+	rb, _ = os.ReadFile(crtF.Name())
 	lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
@@ -348,11 +347,11 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	caKeyF, err = ioutil.TempFile("", "sign-cert.key")
+	caKeyF, err = os.CreateTemp("", "sign-cert.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 
-	caCrtF, err = ioutil.TempFile("", "sign-cert.crt")
+	caCrtF, err = os.CreateTemp("", "sign-cert.crt")
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 

+ 2 - 3
cmd/nebula-cert/verify.go

@@ -4,7 +4,6 @@ import (
 	"flag"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"os"
 	"strings"
 	"time"
@@ -40,7 +39,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	rawCACert, err := ioutil.ReadFile(*vf.caPath)
+	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
 		return fmt.Errorf("error while reading ca: %s", err)
 	}
@@ -57,7 +56,7 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 		}
 	}
 
-	rawCert, err := ioutil.ReadFile(*vf.certPath)
+	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
 		return fmt.Errorf("unable to read crt; %s", err)
 	}

+ 2 - 3
cmd/nebula-cert/verify_test.go

@@ -3,7 +3,6 @@ package main
 import (
 	"bytes"
 	"crypto/rand"
-	"io/ioutil"
 	"os"
 	"testing"
 	"time"
@@ -56,7 +55,7 @@ func Test_verify(t *testing.T) {
 	// invalid ca at path
 	ob.Reset()
 	eb.Reset()
-	caFile, err := ioutil.TempFile("", "verify-ca")
+	caFile, err := os.CreateTemp("", "verify-ca")
 	assert.Nil(t, err)
 	defer os.Remove(caFile.Name())
 
@@ -92,7 +91,7 @@ func Test_verify(t *testing.T) {
 	// invalid crt at path
 	ob.Reset()
 	eb.Reset()
-	certFile, err := ioutil.TempFile("", "verify-cert")
+	certFile, err := os.CreateTemp("", "verify-cert")
 	assert.Nil(t, err)
 	defer os.Remove(certFile.Name())
 

+ 5 - 2
config/config.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io/ioutil"
 	"math"
 	"os"
 	"os/signal"
@@ -122,6 +121,10 @@ func (c *C) HasChanged(k string) bool {
 // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
 // original path provided to Load. The old settings are shallow copied for change detection after the reload.
 func (c *C) CatchHUP(ctx context.Context) {
+	if c.path == "" {
+		return
+	}
+
 	ch := make(chan os.Signal, 1)
 	signal.Notify(ch, syscall.SIGHUP)
 
@@ -358,7 +361,7 @@ func (c *C) parse() error {
 	var m map[interface{}]interface{}
 
 	for _, path := range c.files {
-		b, err := ioutil.ReadFile(path)
+		b, err := os.ReadFile(path)
 		if err != nil {
 			return err
 		}

+ 7 - 8
config/config_test.go

@@ -1,7 +1,6 @@
 package config
 
 import (
-	"io/ioutil"
 	"os"
 	"path/filepath"
 	"testing"
@@ -16,10 +15,10 @@ import (
 
 func TestConfig_Load(t *testing.T) {
 	l := test.NewLogger()
-	dir, err := ioutil.TempDir("", "config-test")
+	dir, err := os.MkdirTemp("", "config-test")
 	// invalid yaml
 	c := NewC(l)
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
 	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
 
 	// simple multi config merge
@@ -29,8 +28,8 @@ func TestConfig_Load(t *testing.T) {
 
 	assert.Nil(t, err)
 
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
-	ioutil.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
+	os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
 	assert.Nil(t, c.Load(dir))
 	expected := map[interface{}]interface{}{
 		"outer": map[interface{}]interface{}{
@@ -120,9 +119,9 @@ func TestConfig_HasChanged(t *testing.T) {
 func TestConfig_ReloadConfig(t *testing.T) {
 	l := test.NewLogger()
 	done := make(chan bool, 1)
-	dir, err := ioutil.TempDir("", "config-test")
+	dir, err := os.MkdirTemp("", "config-test")
 	assert.Nil(t, err)
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
 	c := NewC(l)
 	assert.Nil(t, c.Load(dir))
@@ -131,7 +130,7 @@ func TestConfig_ReloadConfig(t *testing.T) {
 	assert.False(t, c.HasChanged("outer"))
 	assert.False(t, c.HasChanged(""))
 
-	ioutil.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
+	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: ho"), 0644)
 
 	c.RegisterReloadCallback(func(c *C) {
 		done <- true

+ 10 - 5
connection_manager.go

@@ -23,6 +23,7 @@ const (
 	swapPrimary    trafficDecision = 3
 	migrateRelays  trafficDecision = 4
 	tryRehandshake trafficDecision = 5
+	sendTestPacket trafficDecision = 6
 )
 
 type connectionManager struct {
@@ -176,7 +177,7 @@ func (n *connectionManager) Run(ctx context.Context) {
 }
 
 func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)
+	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
 
 	switch decision {
 	case deleteTunnel:
@@ -197,6 +198,9 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 
 	case tryRehandshake:
 		n.tryRehandshake(hostinfo)
+
+	case sendTestPacket:
+		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 	}
 
 	n.resetRelayTrafficCheck(hostinfo)
@@ -289,7 +293,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	}
 }
 
-func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
 	n.hostMap.RLock()
 	defer n.hostMap.RUnlock()
 
@@ -356,6 +360,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
 		return deleteTunnel, hostinfo, nil
 	}
 
+	decision := doNothing
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
 			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
@@ -380,7 +385,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
 		}
 
 		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		decision = sendTestPacket
 
 	} else {
 		if n.l.Level >= logrus.DebugLevel {
@@ -390,7 +395,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
 
 	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
 	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
-	return doNothing, nil, nil
+	return decision, hostinfo, nil
 }
 
 func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
@@ -432,7 +437,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	if !n.intf.disconnectInvalid && err != cert.ErrBlockListed {
+	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
 		// Block listed certificates should always be disconnected
 		return false
 	}

+ 12 - 11
connection_manager_test.go

@@ -21,8 +21,9 @@ var vpnIp iputil.VpnIp
 
 func newTestLighthouse() *LightHouse {
 	lh := &LightHouse{
-		l:       test.NewLogger(),
-		addrMap: map[iputil.VpnIp]*RemoteList{},
+		l:         test.NewLogger(),
+		addrMap:   map[iputil.VpnIp]*RemoteList{},
+		queryChan: make(chan iputil.VpnIp, 10),
 	}
 	lighthouses := map[iputil.VpnIp]struct{}{}
 	staticList := map[iputil.VpnIp]struct{}{}
@@ -253,18 +254,18 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 
 	lh := newTestLighthouse()
 	ifce := &Interface{
-		hostMap:           hostMap,
-		inside:            &test.NoopTun{},
-		outside:           &udp.NoopConn{},
-		firewall:          &Firewall{},
-		lightHouse:        lh,
-		handshakeManager:  NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
-		l:                 l,
-		disconnectInvalid: true,
-		pki:               &PKI{},
+		hostMap:          hostMap,
+		inside:           &test.NoopTun{},
+		outside:          &udp.NoopConn{},
+		firewall:         &Firewall{},
+		lightHouse:       lh,
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		l:                l,
+		pki:              &PKI{},
 	}
 	ifce.pki.cs.Store(cs)
 	ifce.pki.caPool.Store(ncp)
+	ifce.disconnectInvalid.Store(true)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())

+ 0 - 3
connection_state.go

@@ -24,7 +24,6 @@ type ConnectionState struct {
 	messageCounter atomic.Uint64
 	window         *Bits
 	writeLock      sync.Mutex
-	ready          bool
 }
 
 func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
@@ -71,7 +70,6 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		initiator: initiator,
 		window:    b,
-		ready:     false,
 		myCert:    certState.Certificate,
 	}
 
@@ -83,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"certificate":     cs.peerCert,
 		"initiator":       cs.initiator,
 		"message_counter": cs.messageCounter.Load(),
-		"ready":           cs.ready,
 	})
 }

+ 10 - 2
control.go

@@ -11,6 +11,7 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
+	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -29,6 +30,7 @@ type controlHostLister interface {
 type Control struct {
 	f               *Interface
 	l               *logrus.Logger
+	ctx             context.Context
 	cancel          context.CancelFunc
 	sshStart        func()
 	statsStart      func()
@@ -41,7 +43,6 @@ type ControlHostInfo struct {
 	LocalIndex             uint32                  `json:"localIndex"`
 	RemoteIndex            uint32                  `json:"remoteIndex"`
 	RemoteAddrs            []*udp.Addr             `json:"remoteAddrs"`
-	CachedPackets          int                     `json:"cachedPackets"`
 	Cert                   *cert.NebulaCertificate `json:"cert"`
 	MessageCounter         uint64                  `json:"messageCounter"`
 	CurrentRemote          *udp.Addr               `json:"currentRemote"`
@@ -72,6 +73,10 @@ func (c *Control) Start() {
 	c.f.run()
 }
 
+func (c *Control) Context() context.Context {
+	return c.ctx
+}
+
 // Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete
 func (c *Control) Stop() {
 	// Stop the handshakeManager (and other services), to prevent new tunnels from
@@ -227,6 +232,10 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	return
 }
 
+func (c *Control) Device() overlay.Device {
+	return c.f.inside
+}
+
 func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 
 	chi := ControlHostInfo{
@@ -234,7 +243,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
-		CachedPackets:          len(h.packetStore),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
 	}

+ 1 - 2
control_test.go

@@ -96,7 +96,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteAddrs:            []*udp.Addr{remote2, remote1},
-		CachedPackets:          0,
 		Cert:                   crt.Copy(),
 		MessageCounter:         0,
 		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
@@ -105,7 +104,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet

+ 18 - 18
e2e/handshakes_test.go

@@ -20,7 +20,7 @@ import (
 )
 
 func BenchmarkHotPath(b *testing.B) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -44,7 +44,7 @@ func BenchmarkHotPath(b *testing.B) {
 }
 
 func TestGoodHandshake(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -95,7 +95,7 @@ func TestGoodHandshake(t *testing.T) {
 }
 
 func TestWrongResponderHandshake(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
@@ -164,7 +164,7 @@ func TestStage1Race(t *testing.T) {
 	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
 	// But will eventually collapse down to a single tunnel
 
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -241,7 +241,7 @@ func TestStage1Race(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -290,7 +290,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
@@ -341,7 +341,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 }
 
 func TestRelays(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -372,7 +372,7 @@ func TestRelays(t *testing.T) {
 
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -421,7 +421,7 @@ func TestStage1RaceRelays(t *testing.T) {
 
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -508,7 +508,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 	////TODO: assert hostmaps
 }
 func TestRehandshakingRelays(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -538,7 +538,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -612,7 +612,7 @@ func TestRehandshakingRelays(t *testing.T) {
 
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
 	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
@@ -642,7 +642,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -715,7 +715,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 }
 
 func TestRehandshaking(t *testing.T) {
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
 
@@ -737,7 +737,7 @@ func TestRehandshaking(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew my certificate and spin until their sees it")
-	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -811,7 +811,7 @@ func TestRehandshaking(t *testing.T) {
 func TestRehandshakingLoser(t *testing.T) {
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// Should be the one with the new certificate
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
 
@@ -837,7 +837,7 @@ func TestRehandshakingLoser(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew their certificate and spin until mine sees it")
-	_, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
+	_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
 
 	caB, err := ca.MarshalToPEM()
 	if err != nil {
@@ -912,7 +912,7 @@ func TestRaceRegression(t *testing.T) {
 	// This test forces stage 1, stage 2, stage 1 to be received by me from them
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// caused a cross-linked hostinfo
-	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 

+ 118 - 0
e2e/helpers.go

@@ -0,0 +1,118 @@
+package e2e
+
+import (
+	"crypto/rand"
+	"io"
+	"net"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will generate a CA cert
+func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	nc := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "test ca",
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           true,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	if len(ips) > 0 {
+		nc.Details.Ips = ips
+	}
+
+	if len(subnets) > 0 {
+		nc.Details.Subnets = subnets
+	}
+
+	if len(groups) > 0 {
+		nc.Details.Groups = groups
+	}
+
+	err = nc.Sign(cert.Curve_CURVE25519, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := nc.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return nc, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+	issuer, err := ca.Sha256Sum()
+	if err != nil {
+		panic(err)
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	pub, rawPriv := x25519Keypair()
+
+	nc := &cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           name,
+			Ips:            []*net.IPNet{ip},
+			Subnets:        subnets,
+			Groups:         groups,
+			NotBefore:      time.Unix(before.Unix(), 0),
+			NotAfter:       time.Unix(after.Unix(), 0),
+			PublicKey:      pub,
+			IsCA:           false,
+			Issuer:         issuer,
+			InvertedGroups: make(map[string]struct{}),
+		},
+	}
+
+	err = nc.Sign(ca.Details.Curve, key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := nc.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
+}
+
+func x25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}

+ 1 - 110
e2e/helpers_test.go

@@ -4,7 +4,6 @@
 package e2e
 
 import (
-	"crypto/rand"
 	"fmt"
 	"io"
 	"net"
@@ -22,8 +21,6 @@ import (
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
 	"gopkg.in/yaml.v2"
 )
 
@@ -40,7 +37,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		IP:   udpIp,
 		Port: 4242,
 	}
-	_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 
 	caB, err := caCrt.MarshalToPEM()
 	if err != nil {
@@ -108,112 +105,6 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 	return control, vpnIpNet, &udpAddr, c
 }
 
-// newTestCaCert will generate a CA cert
-func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(cert.Curve_CURVE25519, priv)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, priv, pem
-}
-
-// newTestCert will generate a signed certificate with the provided details.
-// Expiry times are defaulted if you do not pass them in
-func newTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		panic(err)
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	pub, rawPriv := x25519Keypair()
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           name,
-			Ips:            []*net.IPNet{ip},
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
 type doneCb func()
 
 func deadline(t *testing.T, seconds time.Duration) doneCb {

+ 5 - 1
examples/config.yml

@@ -11,7 +11,7 @@ pki:
   #blocklist:
   #  - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
-  #disconnect_invalid: false
+  #disconnect_invalid: true
 
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
@@ -330,6 +330,10 @@ logging:
   # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
   #try_interval: 100ms
   #retries: 20
+
+  # query_buffer is the size of the buffer channel for querying lighthouses
+  #query_buffer: 64
+
   # trigger_buffer is the size of the buffer channel for quickly sending handshakes
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64

+ 100 - 0
examples/go_service/main.go

@@ -0,0 +1,100 @@
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"log"
+
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/service"
+)
+
+func main() {
+	if err := run(); err != nil {
+		log.Fatalf("%+v", err)
+	}
+}
+
+func run() error {
+	configStr := `
+tun:
+  user: true
+
+static_host_map:
+  '192.168.100.1': ['localhost:4242']
+
+listen:
+  host: 0.0.0.0
+  port: 4241
+
+lighthouse:
+  am_lighthouse: false
+  interval: 60
+  hosts:
+    - '192.168.100.1'
+
+firewall:
+  outbound:
+    # Allow all outbound traffic from this node
+    - port: any
+      proto: any
+      host: any
+
+  inbound:
+    # Allow icmp between any nebula hosts
+    - port: any
+      proto: icmp
+      host: any
+    - port: any
+      proto: any
+      host: any
+
+pki:
+  ca: /home/rice/Developer/nebula-config/ca.crt
+  cert: /home/rice/Developer/nebula-config/app.crt
+  key: /home/rice/Developer/nebula-config/app.key
+`
+	var config config.C
+	if err := config.LoadString(configStr); err != nil {
+		return err
+	}
+	service, err := service.New(&config)
+	if err != nil {
+		return err
+	}
+
+	ln, err := service.Listen("tcp", ":1234")
+	if err != nil {
+		return err
+	}
+	for {
+		conn, err := ln.Accept()
+		if err != nil {
+			log.Printf("accept error: %s", err)
+			break
+		}
+		defer conn.Close()
+
+		log.Printf("got connection")
+
+		conn.Write([]byte("hello world\n"))
+
+		scanner := bufio.NewScanner(conn)
+		for scanner.Scan() {
+			message := scanner.Text()
+			fmt.Fprintf(conn, "echo: %q\n", message)
+			log.Printf("got message %q", message)
+		}
+
+		if err := scanner.Err(); err != nil {
+			log.Printf("scanner error: %s", err)
+			break
+		}
+	}
+
+	service.Close()
+	if err := service.Wait(); err != nil {
+		return err
+	}
+	return nil
+}

+ 36 - 14
firewall.go

@@ -6,6 +6,7 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"hash/fnv"
 	"net"
 	"reflect"
 	"strconv"
@@ -57,7 +58,7 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps *cidr.Tree4
+	localIps *cidr.Tree4[struct{}]
 
 	rules        string
 	rulesVersion uint16
@@ -110,8 +111,8 @@ type FirewallRule struct {
 	Any       bool
 	Hosts     map[string]struct{}
 	Groups    [][]string
-	CIDR      *cidr.Tree4
-	LocalCIDR *cidr.Tree4
+	CIDR      *cidr.Tree4[struct{}]
+	LocalCIDR *cidr.Tree4[struct{}]
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
@@ -137,7 +138,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 	}
 
-	localIps := cidr.NewTree4()
+	localIps := cidr.NewTree4[struct{}]()
 	for _, ip := range c.Details.Ips {
 		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
@@ -278,6 +279,18 @@ func (f *Firewall) GetRuleHash() string {
 	return hex.EncodeToString(sum[:])
 }
 
+// GetRuleHashFNV returns a uint32 FNV-1 hash representation the rules, for use as a metric value
+func (f *Firewall) GetRuleHashFNV() uint32 {
+	h := fnv.New32a()
+	h.Write([]byte(f.rules))
+	return h.Sum32()
+}
+
+// GetRuleHashes returns both the sha256 and FNV-1 hashes, suitable for logging
+func (f *Firewall) GetRuleHashes() string {
+	return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
+}
+
 func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
 	var table string
 	if inbound {
@@ -391,7 +404,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
 
 	// Make sure remote address matches nebula certificate
 	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		if remoteCidr.Contains(fp.RemoteIP) == nil {
+		ok, _ := remoteCidr.Contains(fp.RemoteIP)
+		if !ok {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
 		}
@@ -404,7 +418,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	if f.localIps.Contains(fp.LocalIP) == nil {
+	ok, _ := f.localIps.Contains(fp.LocalIP)
+	if !ok {
 		f.metrics(incoming).droppedLocalIP.Inc(1)
 		return ErrInvalidLocalIP
 	}
@@ -447,6 +462,7 @@ func (f *Firewall) EmitStats() {
 	conntrack.Unlock()
 	metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
 	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
+	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 
 func (f *Firewall) inConns(packet []byte, fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
@@ -657,8 +673,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
 		return &FirewallRule{
 			Hosts:     make(map[string]struct{}),
 			Groups:    make([][]string, 0),
-			CIDR:      cidr.NewTree4(),
-			LocalCIDR: cidr.NewTree4(),
+			CIDR:      cidr.NewTree4[struct{}](),
+			LocalCIDR: cidr.NewTree4[struct{}](),
 		}
 	}
 
@@ -726,8 +742,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
-		fr.CIDR = cidr.NewTree4()
-		fr.LocalCIDR = cidr.NewTree4()
+		fr.CIDR = cidr.NewTree4[struct{}]()
+		fr.LocalCIDR = cidr.NewTree4[struct{}]()
 	} else {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
@@ -809,12 +825,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		}
 	}
 
-	if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
-		return true
+	if fr.CIDR != nil {
+		ok, _ := fr.CIDR.Contains(p.RemoteIP)
+		if ok {
+			return true
+		}
 	}
 
-	if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
-		return true
+	if fr.LocalCIDR != nil {
+		ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
+		if ok {
+			return true
+		}
 	}
 
 	// No host, group, or cidr matched, bye bye

+ 8 - 4
firewall_test.go

@@ -92,14 +92,16 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
@@ -114,8 +116,10 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
+	ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion

+ 14 - 11
go.mod

@@ -7,29 +7,31 @@ require (
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
-	github.com/flynn/noise v1.0.0
+	github.com/flynn/noise v1.0.1
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
 	github.com/miekg/dns v1.1.56
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.16.0
+	github.com/prometheus/client_golang v1.17.0
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 	github.com/stretchr/testify v1.8.4
-	github.com/vishvananda/netlink v1.1.0
-	golang.org/x/crypto v0.14.0
+	github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
+	golang.org/x/crypto v0.17.0
 	golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53
-	golang.org/x/net v0.17.0
-	golang.org/x/sys v0.13.0
-	golang.org/x/term v0.13.0
+	golang.org/x/net v0.19.0
+	golang.org/x/sync v0.5.0
+	golang.org/x/sys v0.15.0
+	golang.org/x/term v0.15.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	google.golang.org/protobuf v1.31.0
 	gopkg.in/yaml.v2 v2.4.0
+	gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f
 )
 
 require (
@@ -37,14 +39,15 @@ require (
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/golang/protobuf v1.5.3 // indirect
+	github.com/google/btree v1.0.1 // indirect
 	github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.4.0 // indirect
-	github.com/prometheus/common v0.42.0 // indirect
-	github.com/prometheus/procfs v0.10.1 // indirect
-	github.com/rogpeppe/go-internal v1.10.0 // indirect
+	github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
+	github.com/prometheus/common v0.44.0 // indirect
+	github.com/prometheus/procfs v0.11.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
 	golang.org/x/mod v0.12.0 // indirect
+	golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
 	golang.org/x/tools v0.13.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )

+ 31 - 24
go.sum

@@ -22,8 +22,8 @@ github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/flynn/noise v1.0.0 h1:DlTHqmzmvcEiKj+4RYo/imoswx/4r6iBlCMfVtrMXpQ=
-github.com/flynn/noise v1.0.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/flynn/noise v1.0.1 h1:vPp/jdQLXC6ppsXSj/pM3W1BIJ5FEHE2TulSJBpb43Y=
+github.com/flynn/noise v1.0.1/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -47,6 +47,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
+github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -97,28 +99,27 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.16.0 h1:yk/hx9hDbrGHovbci4BY+pRMfSuuat626eFsHb7tmT8=
-github.com/prometheus/client_golang v1.16.0/go.mod h1:Zsulrv/L9oM40tJ7T815tM89lFEugiJ9HzIqaAx4LKc=
+github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
+github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY=
-github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
+github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
+github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.42.0 h1:EKsfXEYo4JpWMHH5cg+KOUWeuJSov1Id8zGR8eeI1YM=
-github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc=
+github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
+github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
-github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
+github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
+github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
-github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
 github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
 github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -136,9 +137,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
-github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
-github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
+github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So=
+github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
+github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -148,8 +149,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
-golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
+golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
+golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 h1:5llv2sWeaMSnA3w2kS57ouQQ4pudlXrR0dCgw51QK9o=
 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -168,8 +169,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
-golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
+golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
+golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -177,31 +178,35 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
+golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
+golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
-golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
+golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
-golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
+golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
+golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
@@ -245,3 +250,5 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f h1:8GE2MRjGiFmfpon8dekPI08jEuNMQzSffVHgdupcO4E=
+gvisor.dev/gvisor v0.0.0-20230504175454-7b0a1988a28f/go.mod h1:pzr6sy8gDLfVmDAg8OYrlKvGEHw5C3PGTiBXBTCx76Q=

+ 0 - 31
handshake.go

@@ -1,31 +0,0 @@
-package nebula
-
-import (
-	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
-)
-
-func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
-	// First remote allow list check before we know the vpnIp
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
-			f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
-			return
-		}
-	}
-
-	switch h.Subtype {
-	case header.HandshakeIXPSK0:
-		switch h.MessageCounter {
-		case 1:
-			ixHandshakeStage1(f, addr, via, packet, h)
-		case 2:
-			newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
-			tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
-			if tearDown && newHostinfo != nil {
-				f.handshakeManager.DeleteHostInfo(newHostinfo)
-			}
-		}
-	}
-
-}

+ 47 - 58
handshake_ix.go

@@ -4,6 +4,7 @@ import (
 	"time"
 
 	"github.com/flynn/noise"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
@@ -13,20 +14,20 @@ import (
 
 // This function constructs a handshake packet, but does not actually send it
 // Sending is done by the handshake manager
-func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
-	err := f.handshakeManager.allocateIndex(hostinfo)
+func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
+	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 	}
 
 	certState := f.pki.GetCertState()
 	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
-	hostinfo.ConnectionState = ci
+	hh.hostinfo.ConnectionState = ci
 
 	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hostinfo.localIndexId,
+		InitiatorIndex: hh.hostinfo.localIndexId,
 		Time:           uint64(time.Now().UnixNano()),
 		Cert:           certState.RawCertificateNoKey,
 	}
@@ -48,7 +49,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
 	hsBytes, err = hs.Marshal()
 
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 	}
@@ -58,7 +59,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 	}
@@ -67,9 +68,8 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
 	// handshake packet 1 from the responder
 	ci.window.Update(f.l, 1)
 
-	hostinfo.HandshakePacket[0] = msg
-	hostinfo.HandshakeReady = true
-	hostinfo.handshakeStart = time.Now()
+	hh.hostinfo.HandshakePacket[0] = msg
+	hh.ready = true
 	return true
 }
 
@@ -174,9 +174,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		},
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -243,19 +240,16 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	if err != nil {
 		switch err {
 		case ErrAlreadySeen:
-			// Update remote if preferred (Note we have to switch to locking
-			// the existing hostinfo, and then switch back so the defer Unlock
-			// higher in this function still works)
-			hostinfo.Unlock()
-			existing.Lock()
+			if hostinfo.multiportRx {
+				// The other host is sending to us with multiport, so only grab the IP
+				addr.Port = hostinfo.remote.Port
+			}
 			// Update remote if preferred
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
-			existing.Unlock()
-			hostinfo.Lock()
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
@@ -356,7 +350,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				WithField("sentCachedPackets", len(hostinfo.packetStore)).
 				Info("Handshake message sent")
 		}
 	} else {
@@ -372,25 +365,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 			WithField("issuer", issuer).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-			WithField("sentCachedPackets", len(hostinfo.packetStore)).
 			Info("Handshake message sent")
 	}
 
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
+	hostinfo.ConnectionState.messageCounter.Store(2)
+	hostinfo.remotes.ResetBlockedRemotes()
 
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
-	if hostinfo == nil {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
+	hh.Lock()
+	defer hh.Unlock()
 
+	hostinfo := hh.hostinfo
 	if addr != nil {
 		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@@ -399,27 +393,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	}
 
 	ci := hostinfo.ConnectionState
-	if ci.ready {
-		if hostinfo.multiportRx {
-			// The other host is sending to us with multiport, so only grab the IP
-			addr.Port = hostinfo.remote.Port
-		}
-
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
-			Info("Handshake is already complete")
-
-		// Update remote if preferred
-		if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
-			// Send a test packet to ensure the other side has also switched to
-			// the preferred remote
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
-		}
-
-		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
-		return false
-	}
-
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
@@ -490,22 +463,22 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
+		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
 			//TODO: this doesnt know if its being added or is being used for caching a packet
 			// Block the current used address
-			newHostInfo.remotes = hostinfo.remotes
-			newHostInfo.remotes.BlockRemote(addr)
+			newHH.hostinfo.remotes = hostinfo.remotes
+			newHH.hostinfo.remotes.BlockRemote(addr)
 
 			// Get the correct remote list for the host we did handshake with
 			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 
-			f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
-				WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
+			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 				Info("Blocked addresses for handshakes")
 
 			// Swap the packet store to benefit the original intended recipient
-			newHostInfo.packetStore = hostinfo.packetStore
-			hostinfo.packetStore = []*cachedPacket{}
+			newHH.packetStore = hh.packetStore
+			hh.packetStore = []*cachedPacket{}
 
 			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
 			hostinfo.vpnIp = vpnIp
@@ -518,7 +491,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	// Mark packet 2 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 2)
 
-	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
+	duration := time.Since(hh.startTime).Nanoseconds()
 	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -526,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 		WithField("durationNs", duration).
-		WithField("sentCachedPackets", len(hostinfo.packetStore)).
+		WithField("sentCachedPackets", len(hh.packetStore)).
 		WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx).
 		Info("Handshake message received")
 
@@ -551,7 +524,23 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
+
+	hostinfo.ConnectionState.messageCounter.Store(2)
+
+	if f.l.Level >= logrus.DebugLevel {
+		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
+	}
+
+	if len(hh.packetStore) > 0 {
+		nb := make([]byte, 12, 12)
+		out := make([]byte, mtu)
+		for _, cp := range hh.packetStore {
+			cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
+		}
+		f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
+	}
+
+	hostinfo.remotes.ResetBlockedRemotes()
 	f.metricHandshakes.Update(duration)
 
 	return false

+ 170 - 94
handshake_manager.go

@@ -46,8 +46,8 @@ type HandshakeManager struct {
 	// Mutex for interacting with the vpnIps and indexes maps
 	sync.RWMutex
 
-	vpnIps  map[iputil.VpnIp]*HostInfo
-	indexes map[uint32]*HostInfo
+	vpnIps  map[iputil.VpnIp]*HandshakeHostInfo
+	indexes map[uint32]*HandshakeHostInfo
 
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
@@ -67,10 +67,47 @@ type HandshakeManager struct {
 	trigger chan iputil.VpnIp
 }
 
+type HandshakeHostInfo struct {
+	sync.Mutex
+
+	startTime   time.Time       // Time that we first started trying with this handshake
+	ready       bool            // Is the handshake ready
+	counter     int             // How many attempts have we made so far
+	lastRemotes []*udp.Addr     // Remotes that we sent to during the previous attempt
+	packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+
+	hostinfo *HostInfo
+}
+
+func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
+	if len(hh.packetStore) < 100 {
+		tempPacket := make([]byte, len(packet))
+		copy(tempPacket, packet)
+
+		hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
+		if l.Level >= logrus.DebugLevel {
+			hh.hostinfo.logger(l).
+				WithField("length", len(hh.packetStore)).
+				WithField("stored", true).
+				Debugf("Packet store")
+		}
+
+	} else {
+		m.dropped.Inc(1)
+
+		if l.Level >= logrus.DebugLevel {
+			hh.hostinfo.logger(l).
+				WithField("length", len(hh.packetStore)).
+				WithField("stored", false).
+				Debugf("Packet store")
+		}
+	}
+}
+
 func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		vpnIps:                 map[iputil.VpnIp]*HostInfo{},
-		indexes:                map[uint32]*HostInfo{},
+		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
+		indexes:                map[uint32]*HandshakeHostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
@@ -100,6 +137,31 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 	}
 }
 
+func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+	// First remote allow list check before we know the vpnIp
+	if addr != nil {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+			return
+		}
+	}
+
+	switch h.Subtype {
+	case header.HandshakeIXPSK0:
+		switch h.MessageCounter {
+		case 1:
+			ixHandshakeStage1(hm.f, addr, via, packet, h)
+
+		case 2:
+			newHostinfo := hm.queryIndex(h.RemoteIndex)
+			tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
+			if tearDown && newHostinfo != nil {
+				hm.DeleteHostInfo(newHostinfo.hostinfo)
+			}
+		}
+	}
+}
+
 func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	c.OutboundHandshakeTimer.Advance(now)
 	for {
@@ -111,41 +173,35 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
-	hostinfo := c.QueryVpnIp(vpnIp)
-	if hostinfo == nil {
-		return
-	}
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
-	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
-	if hostinfo.HandshakeComplete {
-		// Ensure we don't exist in the pending hostmap anymore since we have completed
-		c.DeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+	hh := hm.queryVpnIp(vpnIp)
+	if hh == nil {
 		return
 	}
+	hh.Lock()
+	defer hh.Unlock()
 
+	hostinfo := hh.hostinfo
 	// If we are out of time, clean up
-	if hostinfo.HandshakeCounter >= c.config.retries {
-		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
-			WithField("initiatorIndex", hostinfo.localIndexId).
-			WithField("remoteIndex", hostinfo.remoteIndexId).
+	if hh.counter >= hm.config.retries {
+		hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
+			WithField("initiatorIndex", hh.hostinfo.localIndexId).
+			WithField("remoteIndex", hh.hostinfo.remoteIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
+			WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
 			Info("Handshake timed out")
-		c.metricTimedOut.Inc(1)
-		c.DeleteHostInfo(hostinfo)
+		hm.metricTimedOut.Inc(1)
+		hm.DeleteHostInfo(hostinfo)
 		return
 	}
 
 	// Increment the counter to increase our delay, linear backoff
-	hostinfo.HandshakeCounter++
+	hh.counter++
 
 	// Check if we have a handshake packet to transmit yet
-	if !hostinfo.HandshakeReady {
-		if !ixHandshakeStage0(c.f, hostinfo) {
-			c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+	if !hh.ready {
+		if !ixHandshakeStage0(hm.f, hh) {
+			hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
 			return
 		}
 	}
@@ -155,11 +211,11 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
 	}
 
-	remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
-	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
+	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
+	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// This is a very specific optimization for a fast lighthouse reply.
@@ -168,26 +224,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 		return
 	}
 
-	hostinfo.HandshakeLastRemotes = remotes
+	hh.lastRemotes = remotes
 
 	// TODO: this will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
-	if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 {
+	if len(remotes) <= 1 && hh.counter == 5 {
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
 		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		c.lightHouse.QueryServer(vpnIp, c.f)
+		hm.lightHouse.QueryServer(vpnIp)
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
 	var sentMultiport bool
-	hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
-		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-		err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+	hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
-			hostinfo.logger(c.l).WithField("udpAddr", addr).
+			hostinfo.logger(hm.l).WithField("udpAddr", addr).
 				WithField("initiatorIndex", hostinfo.localIndexId).
 				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake message")
@@ -197,7 +253,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 		}
 
 		// Attempt a multiport handshake if we are past the TxHandshakeDelay attempts
-		if c.multiPort.TxHandshake && c.udpRaw != nil && hostinfo.HandshakeCounter >= c.multiPort.TxHandshakeDelay {
+		if hm.multiPort.TxHandshake && hm.udpRaw != nil && hh.counter >= hm.multiPort.TxHandshakeDelay {
 			sentMultiport = true
 			// We need to re-allocate with 8 bytes at the start of SOCK_RAW
 			raw := hostinfo.HandshakePacket[0x80]
@@ -207,10 +263,10 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 				hostinfo.HandshakePacket[0x80] = raw
 			}
 
-			c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-			err = c.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(c.multiPort.TxPorts), addr)
+			hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+			err = hm.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(hm.multiPort.TxPorts), addr)
 			if err != nil {
-				hostinfo.logger(c.l).WithField("udpAddr", addr).
+				hostinfo.logger(hm.l).WithField("udpAddr", addr).
 					WithField("initiatorIndex", hostinfo.localIndexId).
 					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 					WithError(err).Error("Failed to send handshake message")
@@ -221,64 +277,64 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 	// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
 	// so only log when the list of remotes has changed
 	if remotesHaveChanged {
-		hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			WithField("multiportHandshake", sentMultiport).
 			Info("Handshake message sent")
-	} else if c.l.IsLevelEnabled(logrus.DebugLevel) {
-		hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+	} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
+		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Debug("Handshake message sent")
 	}
 
-	if c.config.useRelays && len(hostinfo.remotes.relays) > 0 {
-		hostinfo.logger(c.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
+	if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
+		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
+			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
 				continue
 			}
-			relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
+			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
 			if relayHostInfo == nil || relayHostInfo.remote == nil {
-				hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				c.f.Handshake(*relay)
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
+				hm.f.Handshake(*relay)
 				continue
 			}
 			// Check the relay HostInfo to see if we already established a relay through it
 			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
 				switch existingRelay.State {
 				case Established:
-					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 				case Requested:
-					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 					// Re-send the CreateRelay request, in case the previous one was lost.
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         uint32(c.lightHouse.myVpnIp),
+						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
 						RelayToIp:           uint32(vpnIp),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
-						hostinfo.logger(c.l).
+						hostinfo.logger(hm.l).
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
 						// This must send over the hostinfo, not over hm.Hosts[ip]
-						c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						c.l.WithFields(logrus.Fields{
-							"relayFrom":           c.lightHouse.myVpnIp,
+						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						hm.l.WithFields(logrus.Fields{
+							"relayFrom":           hm.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
 							"relay":               *relay}).
 							Info("send CreateRelayRequest")
 					}
 				default:
-					hostinfo.logger(c.l).
+					hostinfo.logger(hm.l).
 						WithField("vpnIp", vpnIp).
 						WithField("state", existingRelay.State).
 						WithField("relay", relayHostInfo.vpnIp).
@@ -287,26 +343,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 			} else {
 				// No relays exist or requested yet.
 				if relayHostInfo.remote != nil {
-					idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested)
+					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					if err != nil {
-						hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
+						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         uint32(c.lightHouse.myVpnIp),
+						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
 						RelayToIp:           uint32(vpnIp),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
-						hostinfo.logger(c.l).
+						hostinfo.logger(hm.l).
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						c.l.WithFields(logrus.Fields{
-							"relayFrom":           c.lightHouse.myVpnIp,
+						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						hm.l.WithFields(logrus.Fields{
+							"relayFrom":           hm.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               *relay}).
@@ -319,13 +375,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
 	}
 }
 
 // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
 // The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) {
+func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	// Check the main hostmap and maintain a read lock if our host is not there
 	hm.mainHostMap.RLock()
 	if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
@@ -342,16 +398,16 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos
 }
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
+	defer hm.Unlock()
 
-	if hostinfo, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnIp]; ok {
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
-			cacheCb(hostinfo)
+			cacheCb(hh)
 		}
-		hm.Unlock()
-		return hostinfo
+		return hh.hostinfo
 	}
 
 	hostinfo := &HostInfo{
@@ -364,12 +420,16 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos
 		},
 	}
 
-	hm.vpnIps[vpnIp] = hostinfo
+	hh := &HandshakeHostInfo{
+		hostinfo:  hostinfo,
+		startTime: time.Now(),
+	}
+	hm.vpnIps[vpnIp] = hh
 	hm.metricInitiated.Inc(1)
 	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
 
 	if cacheCb != nil {
-		cacheCb(hostinfo)
+		cacheCb(hh)
 	}
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
@@ -387,8 +447,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos
 		}
 	}
 
-	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp, hm.f)
+	hm.lightHouse.QueryServer(vpnIp)
 	return hostinfo
 }
 
@@ -442,8 +501,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingIndex, found = c.indexes[hostinfo.localIndexId]
-	if found && existingIndex != hostinfo {
+	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
@@ -487,7 +546,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 // allocateIndex generates a unique localIndexId for this HostInfo
 // and adds it to the pendingHostMap. Will error if we are unable to generate
 // a unique localIndexId
-func (hm *HandshakeManager) allocateIndex(h *HostInfo) error {
+func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	hm.mainHostMap.RLock()
 	defer hm.mainHostMap.RUnlock()
 	hm.Lock()
@@ -503,8 +562,8 @@ func (hm *HandshakeManager) allocateIndex(h *HostInfo) error {
 		_, inMain := hm.mainHostMap.Indexes[index]
 
 		if !inMain && !inPending {
-			h.localIndexId = index
-			hm.indexes[index] = h
+			hh.hostinfo.localIndexId = index
+			hm.indexes[index] = hh
 			return nil
 		}
 	}
@@ -521,12 +580,12 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
 func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	delete(c.vpnIps, hostinfo.vpnIp)
 	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[iputil.VpnIp]*HostInfo{}
+		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
 	}
 
 	delete(c.indexes, hostinfo.localIndexId)
 	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HostInfo{}
+		c.indexes = map[uint32]*HandshakeHostInfo{}
 	}
 
 	if c.l.Level >= logrus.DebugLevel {
@@ -536,16 +595,33 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 }
 
-func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
-	c.RLock()
-	defer c.RUnlock()
-	return c.vpnIps[vpnIp]
+func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	hh := hm.queryVpnIp(vpnIp)
+	if hh != nil {
+		return hh.hostinfo
+	}
+	return nil
+
 }
 
-func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
-	c.RLock()
-	defer c.RUnlock()
-	return c.indexes[index]
+func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+	hm.RLock()
+	defer hm.RUnlock()
+	return hm.vpnIps[vpnIp]
+}
+
+func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo {
+	hh := hm.queryIndex(index)
+	if hh != nil {
+		return hh.hostinfo
+	}
+	return nil
+}
+
+func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
+	hm.RLock()
+	defer hm.RUnlock()
+	return hm.indexes[index]
 }
 
 func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
@@ -557,7 +633,7 @@ func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
 	defer c.RUnlock()
 
 	for _, v := range c.vpnIps {
-		f(v)
+		f(v.hostinfo)
 	}
 }
 
@@ -566,7 +642,7 @@ func (c *HandshakeManager) ForEachIndex(f controlEach) {
 	defer c.RUnlock()
 
 	for _, v := range c.indexes {
-		f(v)
+		f(v.hostinfo)
 	}
 }
 

+ 10 - 1
handshake_manager_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
@@ -21,7 +22,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	mainHM := NewHostMap(l, vpncidr, preferredRanges)
 	lh := newTestLighthouse()
 
+	cs := &CertState{
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
+	}
+
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
+	blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l}
+	blah.f.pki.cs.Store(cs)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now)
@@ -31,7 +41,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.Same(t, i, i2)
 
 	i.remotes = NewRemoteList(nil)
-	i.HandshakeReady = true
 
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)

+ 25 - 82
hostmap.go

@@ -21,6 +21,7 @@ const defaultPromoteEvery = 1000       // Count of packets sent before we try mo
 const defaultReQueryEvery = 5000       // Count of packets sent before re-querying a hostinfo to the lighthouse
 const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
 const MaxRemotes = 10
+const maxRecvError = 4
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
 // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -196,27 +197,26 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
 }
 
 type HostInfo struct {
-	sync.RWMutex
-
-	remote               *udp.Addr
-	remotes              *RemoteList
-	promoteCounter       atomic.Uint32
-	multiportTx          bool
-	multiportRx          bool
-	ConnectionState      *ConnectionState
-	handshakeStart       time.Time   //todo: this an entry in the handshake manager
-	HandshakeReady       bool        //todo: being in the manager means you are ready
-	HandshakeCounter     int         //todo: another handshake manager entry
-	HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
-	HandshakeComplete    bool        //todo: this should go away in favor of ConnectionState.ready
-	HandshakePacket      map[uint8][]byte
-	packetStore          []*cachedPacket //todo: this is other handshake manager entry
-	remoteIndexId        uint32
-	localIndexId         uint32
-	vpnIp                iputil.VpnIp
-	recvError            int
-	remoteCidr           *cidr.Tree4
-	relayState           RelayState
+	remote          *udp.Addr
+	remotes         *RemoteList
+	promoteCounter  atomic.Uint32
+	ConnectionState *ConnectionState
+	remoteIndexId   uint32
+	localIndexId    uint32
+	vpnIp           iputil.VpnIp
+	recvError       atomic.Uint32
+	remoteCidr      *cidr.Tree4[struct{}]
+	relayState      RelayState
+
+	// If true, we should send to this remote using multiport
+	multiportTx bool
+
+	// If true, we will receive from this remote using multiport
+	multiportRx bool
+
+	// HandshakePacket records the packets used to create this hostinfo
+	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
+	HandshakePacket map[uint8][]byte
 
 	// nextLHQuery is the earliest we can ask the lighthouse for new information.
 	// This is used to limit lighthouse re-queries in chatty clients
@@ -414,7 +414,6 @@ func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
 }
 
 func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
 	hm.RLock()
 	if h, ok := hm.Relays[index]; ok {
 		hm.RUnlock()
@@ -537,10 +536,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
 	if c%ifce.tryPromoteEvery.Load() == 0 {
-		// The lock here is currently protecting i.remote access
-		i.RLock()
 		remote := i.remote
-		i.RUnlock()
 
 		// return early if we are already on a preferred remote
 		if remote != nil {
@@ -571,62 +567,10 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 		}
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp, ifce)
-	}
-}
-
-func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
-	//TODO: return the error so we can log with more context
-	if len(i.packetStore) < 100 {
-		tempPacket := make([]byte, len(packet))
-		copy(tempPacket, packet)
-		//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
-		i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
-		if l.Level >= logrus.DebugLevel {
-			i.logger(l).
-				WithField("length", len(i.packetStore)).
-				WithField("stored", true).
-				Debugf("Packet store")
-		}
-
-	} else if l.Level >= logrus.DebugLevel {
-		m.dropped.Inc(1)
-		i.logger(l).
-			WithField("length", len(i.packetStore)).
-			WithField("stored", false).
-			Debugf("Packet store")
+		ifce.lightHouse.QueryServer(i.vpnIp)
 	}
 }
 
-// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
-func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
-	//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
-	//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
-	//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
-
-	i.HandshakeComplete = true
-	//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
-	// Clamping it to 2 gets us out of the woods for now
-	i.ConnectionState.messageCounter.Store(2)
-
-	if l.Level >= logrus.DebugLevel {
-		i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
-	}
-
-	if len(i.packetStore) > 0 {
-		nb := make([]byte, 12, 12)
-		out := make([]byte, mtu)
-		for _, cp := range i.packetStore {
-			cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
-		}
-		m.sent.Inc(int64(len(i.packetStore)))
-	}
-
-	i.remotes.ResetBlockedRemotes()
-	i.packetStore = make([]*cachedPacket, 0)
-	i.ConnectionState.ready = true
-}
-
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
@@ -683,9 +627,8 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 }
 
 func (i *HostInfo) RecvErrorExceeded() bool {
-	if i.recvError < 3 {
-		i.recvError += 1
-		return false
+	if i.recvError.Add(1) >= maxRecvError {
+		return true
 	}
 	return true
 }
@@ -696,7 +639,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 	}
 
-	remoteCidr := cidr.NewTree4()
+	remoteCidr := cidr.NewTree4[struct{}]()
 	for _, ip := range c.Details.Ips {
 		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}

+ 26 - 12
inside.go

@@ -44,8 +44,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) {
-		h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
 	if hostinfo == nil {
@@ -83,6 +83,10 @@ func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
 	}
 
 	out = iputil.CreateRejectPacket(packet, out)
+	if len(out) == 0 {
+		return
+	}
+
 	_, err := f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
@@ -94,12 +98,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 	}
 
-	// Use some out buffer space to build the packet before encryption
-	// Need 40 bytes for the reject packet (20 byte ipv4 header, 20 byte tcp rst packet)
-	// Leave 100 bytes for the encrypted packet (60 byte Nebula header, 40 byte reject packet)
-	out = out[:140]
-	outPacket := iputil.CreateRejectPacket(packet, out[100:])
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, outPacket, nb, out, q, nil)
+	out = iputil.CreateRejectPacket(packet, out)
+	if len(out) == 0 {
+		return
+	}
+
+	if len(out) > iputil.MaxRejectPacketSize {
+		if f.l.GetLevel() >= logrus.InfoLevel {
+			f.l.
+				WithField("packet", packet).
+				WithField("outPacket", out).
+				Info("rejectOutside: packet too big, not sending")
+		}
+		return
+	}
+
+	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q, nil)
 }
 
 func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
@@ -108,7 +122,7 @@ func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
 
 // getOrHandshake returns nil if the vpnIp is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) {
+func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
 		vpnIp = f.inside.RouteFor(vpnIp)
 		if vpnIp == 0 {
@@ -143,8 +157,8 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) {
-		h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
+	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 
 	if hostInfo == nil {
@@ -291,7 +305,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp, f)
+		f.lightHouse.QueryServer(hostinfo.vpnIp)
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")

+ 17 - 7
interface.go

@@ -40,7 +40,6 @@ type InterfaceConfig struct {
 	routines                int
 	MessageMetrics          *MessageMetrics
 	version                 string
-	disconnectInvalid       bool
 	relayManager            *relayManager
 	punchy                  *Punchy
 
@@ -69,7 +68,7 @@ type Interface struct {
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	routines           int
-	disconnectInvalid  bool
+	disconnectInvalid  atomic.Bool
 	closed             atomic.Bool
 	relayManager       *relayManager
 
@@ -188,7 +187,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
-		disconnectInvalid:  c.disconnectInvalid,
 		myVpnIp:            myVpnIp,
 		relayManager:       c.relayManager,
 
@@ -308,12 +306,24 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
+	c.RegisterReloadCallback(f.reloadDisconnectInvalid)
 	c.RegisterReloadCallback(f.reloadMisc)
+
 	for _, udpConn := range f.writers {
 		c.RegisterReloadCallback(udpConn.ReloadConfig)
 	}
 }
 
+func (f *Interface) reloadDisconnectInvalid(c *config.C) {
+	initial := c.InitialLoad()
+	if initial || c.HasChanged("pki.disconnect_invalid") {
+		f.disconnectInvalid.Store(c.GetBool("pki.disconnect_invalid", true))
+		if !initial {
+			f.l.Infof("pki.disconnect_invalid changed to %v", f.disconnectInvalid.Load())
+		}
+	}
+}
+
 func (f *Interface) reloadFirewall(c *config.C) {
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
@@ -336,8 +346,8 @@ func (f *Interface) reloadFirewall(c *config.C) {
 	// If rulesVersion is back to zero, we have wrapped all the way around. Be
 	// safe and just reset conntrack in this case.
 	if fw.rulesVersion == 0 {
-		f.l.WithField("firewallHash", fw.GetRuleHash()).
-			WithField("oldFirewallHash", oldFw.GetRuleHash()).
+		f.l.WithField("firewallHashes", fw.GetRuleHashes()).
+			WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
 			WithField("rulesVersion", fw.rulesVersion).
 			Warn("firewall rulesVersion has overflowed, resetting conntrack")
 	} else {
@@ -347,8 +357,8 @@ func (f *Interface) reloadFirewall(c *config.C) {
 	f.firewall = fw
 
 	oldFw.Destroy()
-	f.l.WithField("firewallHash", fw.GetRuleHash()).
-		WithField("oldFirewallHash", oldFw.GetRuleHash()).
+	f.l.WithField("firewallHashes", fw.GetRuleHashes()).
+		WithField("oldFirewallHashes", oldFw.GetRuleHashes()).
 		WithField("rulesVersion", fw.rulesVersion).
 		Info("New firewall has been installed")
 }

+ 34 - 7
iputil/packet.go

@@ -6,8 +6,19 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
+const (
+	// Need 96 bytes for the largest reject packet:
+	// - 20 byte ipv4 header
+	// - 8 byte icmpv4 header
+	// - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header)
+	MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8
+)
+
 func CreateRejectPacket(packet []byte, out []byte) []byte {
-	// TODO ipv4 only, need to fix when inside supports ipv6
+	if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version {
+		return nil
+	}
+
 	switch packet[9] {
 	case 6: // tcp
 		return ipv4CreateRejectTCPPacket(packet, out)
@@ -19,20 +30,28 @@ func CreateRejectPacket(packet []byte, out []byte) []byte {
 func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte {
 	ihl := int(packet[0]&0x0f) << 2
 
-	// ICMP reply includes header and first 8 bytes of the packet
+	if len(packet) < ihl {
+		// We need at least this many bytes for this to be a valid packet
+		return nil
+	}
+
+	// ICMP reply includes original header and first 8 bytes of the packet
 	packetLen := len(packet)
 	if packetLen > ihl+8 {
 		packetLen = ihl + 8
 	}
 
 	outLen := ipv4.HeaderLen + 8 + packetLen
+	if outLen > cap(out) {
+		return nil
+	}
 
-	out = out[:(outLen)]
+	out = out[:outLen]
 
 	ipHdr := out[0:ipv4.HeaderLen]
-	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)                        // version, ihl
-	ipHdr[1] = 0                                                              // DSCP, ECN
-	binary.BigEndian.PutUint16(ipHdr[2:], uint16(ipv4.HeaderLen+8+packetLen)) // Total Length
+	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl
+	ipHdr[1] = 0                                          // DSCP, ECN
+	binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length
 
 	ipHdr[4] = 0  // id
 	ipHdr[5] = 0  //  .
@@ -76,7 +95,15 @@ func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte {
 	ihl := int(packet[0]&0x0f) << 2
 	outLen := ipv4.HeaderLen + tcpLen
 
-	out = out[:(outLen)]
+	if len(packet) < ihl+tcpLen {
+		// We need at least this many bytes for this to be a valid packet
+		return nil
+	}
+	if outLen > cap(out) {
+		return nil
+	}
+
+	out = out[:outLen]
 
 	ipHdr := out[0:ipv4.HeaderLen]
 	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl

+ 73 - 0
iputil/packet_test.go

@@ -0,0 +1,73 @@
+package iputil
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"golang.org/x/net/ipv4"
+)
+
+func Test_CreateRejectPacket(t *testing.T) {
+	h := ipv4.Header{
+		Len:      20,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 1, // ICMP
+	}
+
+	b, err := h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4}...)
+
+	expectedLen := ipv4.HeaderLen + 8 + h.Len + 4
+	out := make([]byte, expectedLen)
+	rejectPacket := CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+
+	// ICMP with max header len
+	h = ipv4.Header{
+		Len:      60,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 1, // ICMP
+		Options:  make([]byte, 40),
+	}
+
+	b, err = h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4, 0, 0, 0, 0}...)
+
+	expectedLen = MaxRejectPacketSize
+	out = make([]byte, MaxRejectPacketSize)
+	rejectPacket = CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+
+	// TCP with max header len
+	h = ipv4.Header{
+		Len:      60,
+		Src:      net.IPv4(10, 0, 0, 1),
+		Dst:      net.IPv4(10, 0, 0, 2),
+		Protocol: 6, // TCP
+		Options:  make([]byte, 40),
+	}
+
+	b, err = h.Marshal()
+	if err != nil {
+		t.Fatalf("h.Marhshal: %v", err)
+	}
+	b = append(b, []byte{0, 3, 0, 4}...)
+	b = append(b, make([]byte, 16)...)
+
+	expectedLen = ipv4.HeaderLen + 20
+	out = make([]byte, expectedLen)
+	rejectPacket = CreateRejectPacket(b, out)
+	assert.NotNil(t, rejectPacket)
+	assert.Len(t, rejectPacket, expectedLen)
+}

+ 56 - 28
lighthouse.go

@@ -74,7 +74,9 @@ type LightHouse struct {
 	// IP's of relays that can be used by peers to access me
 	relaysForMe atomic.Pointer[[]iputil.VpnIp]
 
-	calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote
+	queryChan chan iputil.VpnIp
+
+	calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
 
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
@@ -110,6 +112,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		nebulaPort:   nebulaPort,
 		punchConn:    pc,
 		punchy:       p,
+		queryChan:    make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
 		l:            l,
 	}
 	lighthouses := make(map[iputil.VpnIp]struct{})
@@ -139,6 +142,8 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		}
 	})
 
+	h.startQueryWorker()
+
 	return &h, nil
 }
 
@@ -166,7 +171,7 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
 	return *lh.relaysForMe.Load()
 }
 
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 {
+func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
 	return lh.calculatedRemotes.Load()
 }
 
@@ -443,9 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	return nil
 }
 
-func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
+func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
-		lh.QueryServer(ip, f)
+		lh.QueryServer(ip)
 	}
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
@@ -456,30 +461,14 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
 	return nil
 }
 
-// This is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
-	if lh.amLighthouse {
-		return
-	}
-
-	if lh.IsLighthouseIP(ip) {
-		return
-	}
-
-	// Send a query to the lighthouses and hope for the best next time
-	query, err := NewLhQueryByInt(ip).Marshal()
-	if err != nil {
-		lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
+// QueryServer is asynchronous so no reply should be expected
+func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
+	// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
+	if lh.amLighthouse || lh.IsLighthouseIP(ip) {
 		return
 	}
 
-	lighthouses := lh.GetLighthouses()
-	lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
-	nb := make([]byte, 12, 12)
-	out := make([]byte, mtu)
-	for n := range lighthouses {
-		f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
-	}
+	lh.queryChan <- ip
 }
 
 func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
@@ -594,11 +583,10 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 	if tree == nil {
 		return false
 	}
-	value := tree.MostSpecificContains(vpnIp)
-	if value == nil {
+	ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+	if !ok {
 		return false
 	}
-	calculatedRemotes := value.([]*calculatedRemote)
 
 	var calculated []*Ip4AndPort
 	for _, cr := range calculatedRemotes {
@@ -753,6 +741,46 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
 	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 
+func (lh *LightHouse) startQueryWorker() {
+	if lh.amLighthouse {
+		return
+	}
+
+	go func() {
+		nb := make([]byte, 12, 12)
+		out := make([]byte, mtu)
+
+		for {
+			select {
+			case <-lh.ctx.Done():
+				return
+			case ip := <-lh.queryChan:
+				lh.innerQueryServer(ip, nb, out)
+			}
+		}
+	}()
+}
+
+func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
+	if lh.IsLighthouseIP(ip) {
+		return
+	}
+
+	// Send a query to the lighthouses and hope for the best next time
+	query, err := NewLhQueryByInt(ip).Marshal()
+	if err != nil {
+		lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
+		return
+	}
+
+	lighthouses := lh.GetLighthouses()
+	lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
+
+	for n := range lighthouses {
+		lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
+	}
+}
+
 func (lh *LightHouse) StartUpdateWorker() {
 	interval := lh.GetUpdateInterval()
 	if lh.amLighthouse || interval == 0 {

+ 23 - 4
main.go

@@ -18,7 +18,7 @@ import (
 
 type m map[string]interface{}
 
-func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, tunFd *int) (retcon *Control, reterr error) {
+func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	// Automatically cancel the context if Main returns an error, to signal all created goroutines to quit.
 	defer func() {
@@ -65,12 +65,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
-	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
+	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
 	// TODO: make sure mask is 4 bytes
 	tunCidr := certificate.Details.Ips[0]
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
+	if err != nil {
+		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
+	}
 	wireSSHReload(l, ssh, c)
 	var sshStart func()
 	if c.GetBool("sshd.enabled", false) {
@@ -125,7 +128,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if !configTest {
 		c.CatchHUP(ctx)
 
-		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
+		if deviceFactory == nil {
+			deviceFactory = overlay.NewDeviceFromConfig
+		}
+
+		tun, err = deviceFactory(c, l, tunCidr, routines)
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
@@ -156,12 +163,23 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 
 		for i := 0; i < routines; i++ {
+			l.Infof("listening %q %d", listenHost.IP, port)
 			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
 			udpServer.ReloadConfig(c)
 			udpConns[i] = udpServer
+
+			// If port is dynamic, discover it before the next pass through the for loop
+			// This way all routines will use the same port correctly
+			if port == 0 {
+				uPort, err := udpServer.LocalAddr()
+				if err != nil {
+					return nil, util.NewContextualError("Failed to get listening port", nil, err)
+				}
+				port = int(uPort.Port)
+			}
 		}
 	}
 
@@ -270,7 +288,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
-		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 		relayManager:            NewRelayManager(ctx, l, hostMap, c),
 		punchy:                  punchy,
 
@@ -333,6 +350,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		c.RegisterReloadCallback(loadMultiPortConfig)
 
 		ifce.RegisterConfigChangeCallbacks(c)
+		ifce.reloadDisconnectInvalid(c)
 		ifce.reloadSendRecvError(c)
 
 		handshakeManager.f = ifce
@@ -365,6 +383,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	return &Control{
 		ifce,
 		l,
+		ctx,
 		cancel,
 		sshStart,
 		statsStart,

+ 4 - 5
outside.go

@@ -198,7 +198,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		HandleIncomingHandshake(f, addr, via, packet, h, hostinfo)
+		f.handshakeManager.HandleIncoming(addr, via, packet, h)
 		return
 
 	case header.RecvError:
@@ -419,7 +419,9 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 
 	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
-		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
+		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
+		// This gives us a buffer to build the reject packet in
+		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 				WithField("reason", dropReason).
@@ -468,9 +470,6 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 		return
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	if !hostinfo.RecvErrorExceeded() {
 		return
 	}

+ 2 - 2
overlay/route.go

@@ -21,8 +21,8 @@ type Route struct {
 	Install bool
 }
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
-	routeTree := cidr.NewTree4()
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
+	routeTree := cidr.NewTree4[iputil.VpnIp]()
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)

+ 8 - 10
overlay/route_test.go

@@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) {
 	assert.NoError(t, err)
 
 	ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
-	r := routeTree.MostSpecificContains(ip)
-	assert.NotNil(t, r)
-	assert.IsType(t, iputil.VpnIp(0), r)
-	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
+	ok, r := routeTree.MostSpecificContains(ip)
+	assert.True(t, ok)
+	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
 
 	ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
-	r = routeTree.MostSpecificContains(ip)
-	assert.NotNil(t, r)
-	assert.IsType(t, iputil.VpnIp(0), r)
-	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
+	ok, r = routeTree.MostSpecificContains(ip)
+	assert.True(t, ok)
+	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
 
 	ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
-	r = routeTree.MostSpecificContains(ip)
-	assert.Nil(t, r)
+	ok, r = routeTree.MostSpecificContains(ip)
+	assert.False(t, ok)
 }

+ 24 - 8
overlay/tun.go

@@ -10,7 +10,9 @@ import (
 
 const DefaultMTU = 1300
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *int, routines int) (Device, error) {
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
+
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
 	routes, err := parseRoutes(c, tunCidr)
 	if err != nil {
 		return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
@@ -27,27 +29,41 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, fd *
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 
-	case fd != nil:
-		return newTunFromFd(
+	default:
+		return newTun(
 			l,
-			*fd,
+			c.GetString("tun.dev", ""),
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
+			routines > 1,
 			c.GetBool("tun.use_system_route_table", false),
 		)
+	}
+}
 
-	default:
-		return newTun(
+func NewFdDeviceFromConfig(fd *int) DeviceFactory {
+	return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+		routes, err := parseRoutes(c, tunCidr)
+		if err != nil {
+			return nil, util.NewContextualError("Could not parse tun.routes", nil, err)
+		}
+
+		unsafeRoutes, err := parseUnsafeRoutes(c, tunCidr)
+		if err != nil {
+			return nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
+		}
+		routes = append(routes, unsafeRoutes...)
+		return newTunFromFd(
 			l,
-			c.GetString("tun.dev", ""),
+			*fd,
 			tunCidr,
 			c.GetInt("tun.mtu", DefaultMTU),
 			routes,
 			c.GetInt("tun.tx_queue", 500),
-			routines > 1,
 			c.GetBool("tun.use_system_route_table", false),
 		)
+
 	}
 }

+ 3 - 7
overlay/tun_android.go

@@ -18,7 +18,7 @@ type tun struct {
 	io.ReadWriteCloser
 	fd        int
 	cidr      *net.IPNet
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 }
 
@@ -46,12 +46,8 @@ func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t tun) Activate() error {

+ 4 - 4
overlay/tun_darwin.go

@@ -25,7 +25,7 @@ type tun struct {
 	cidr       *net.IPNet
 	DefaultMTU int
 	Routes     []Route
-	routeTree  *cidr.Tree4
+	routeTree  *cidr.Tree4[iputil.VpnIp]
 	l          *logrus.Logger
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -304,9 +304,9 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
+	ok, r := t.routeTree.MostSpecificContains(ip)
+	if ok {
+		return r
 	}
 
 	return 0

+ 3 - 7
overlay/tun_freebsd.go

@@ -48,7 +48,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -192,12 +192,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_ios.go

@@ -20,7 +20,7 @@ import (
 type tun struct {
 	io.ReadWriteCloser
 	cidr      *net.IPNet
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 }
 
 func newTun(_ *logrus.Logger, _ string, _ *net.IPNet, _ int, _ []Route, _ int, _ bool, _ bool) (*tun, error) {
@@ -46,12 +46,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 // The following is hoisted up from water, we do this so we can inject our own fd on iOS

+ 5 - 9
overlay/tun_linux.go

@@ -30,7 +30,7 @@ type tun struct {
 	TXQueueLen int
 
 	Routes          []Route
-	routeTree       atomic.Pointer[cidr.Tree4]
+	routeTree       atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
 	routeChan       chan struct{}
 	useSystemRoutes bool
 
@@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.Load().MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.Load().MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Write(b []byte) (int, error) {
@@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	newTree := cidr.NewTree4()
+	newTree := cidr.NewTree4[iputil.VpnIp]()
 	if r.Type == unix.RTM_NEWROUTE {
 		for _, oldR := range t.routeTree.Load().List() {
 			newTree.AddCIDR(oldR.CIDR, oldR.Value)
@@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	} else {
 		gw := iputil.Ip2VpnIp(r.Gw)
 		for _, oldR := range t.routeTree.Load().List() {
-			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
+			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
 				// This is the record to delete
 				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 				continue

+ 3 - 7
overlay/tun_netbsd.go

@@ -29,7 +29,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -134,12 +134,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_openbsd.go

@@ -23,7 +23,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -115,12 +115,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_tester.go

@@ -19,7 +19,7 @@ type TestTun struct {
 	Device    string
 	cidr      *net.IPNet
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	closed    atomic.Bool
@@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte {
 //********************************************************************************************************************//
 
 func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *TestTun) Activate() error {

+ 3 - 7
overlay/tun_water_windows.go

@@ -18,7 +18,7 @@ type waterTun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 
 	*water.Interface
 }
@@ -97,12 +97,8 @@ func (t *waterTun) Activate() error {
 }
 
 func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *waterTun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_wintun_windows.go

@@ -24,7 +24,7 @@ type winTun struct {
 	prefix    netip.Prefix
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 
 	tun *wintun.NativeTun
 }
@@ -146,12 +146,8 @@ func (t *winTun) Activate() error {
 }
 
 func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *winTun) Cidr() *net.IPNet {

+ 63 - 0
overlay/user.go

@@ -0,0 +1,63 @@
+package overlay
+
+import (
+	"io"
+	"net"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/iputil"
+)
+
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+	return NewUserDevice(tunCidr)
+}
+
+func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+	// these pipes guarantee each write/read will match 1:1
+	or, ow := io.Pipe()
+	ir, iw := io.Pipe()
+	return &UserDevice{
+		tunCidr:        tunCidr,
+		outboundReader: or,
+		outboundWriter: ow,
+		inboundReader:  ir,
+		inboundWriter:  iw,
+	}, nil
+}
+
+type UserDevice struct {
+	tunCidr *net.IPNet
+
+	outboundReader *io.PipeReader
+	outboundWriter *io.PipeWriter
+
+	inboundReader *io.PipeReader
+	inboundWriter *io.PipeWriter
+}
+
+func (d *UserDevice) Activate() error {
+	return nil
+}
+func (d *UserDevice) Cidr() *net.IPNet                      { return d.tunCidr }
+func (d *UserDevice) Name() string                          { return "faketun0" }
+func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return d, nil
+}
+
+func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {
+	return d.inboundReader, d.outboundWriter
+}
+
+func (d *UserDevice) Read(p []byte) (n int, err error) {
+	return d.outboundReader.Read(p)
+}
+func (d *UserDevice) Write(p []byte) (n int, err error) {
+	return d.inboundWriter.Write(p)
+}
+func (d *UserDevice) Close() error {
+	d.inboundWriter.Close()
+	d.outboundWriter.Close()
+	return nil
+}

+ 36 - 0
service/listener.go

@@ -0,0 +1,36 @@
+package service
+
+import (
+	"io"
+	"net"
+)
+
+type tcpListener struct {
+	port   uint16
+	s      *Service
+	addr   *net.TCPAddr
+	accept chan net.Conn
+}
+
+func (l *tcpListener) Accept() (net.Conn, error) {
+	conn, ok := <-l.accept
+	if !ok {
+		return nil, io.EOF
+	}
+	return conn, nil
+}
+
+func (l *tcpListener) Close() error {
+	l.s.mu.Lock()
+	defer l.s.mu.Unlock()
+	delete(l.s.mu.listeners, uint16(l.addr.Port))
+
+	close(l.accept)
+
+	return nil
+}
+
+// Addr returns the listener's network address.
+func (l *tcpListener) Addr() net.Addr {
+	return l.addr
+}

+ 248 - 0
service/service.go

@@ -0,0 +1,248 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"math"
+	"net"
+	"os"
+	"strings"
+	"sync"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/overlay"
+	"golang.org/x/sync/errgroup"
+	"gvisor.dev/gvisor/pkg/bufferv2"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+const nicID = 1
+
+type Service struct {
+	eg      *errgroup.Group
+	control *nebula.Control
+	ipstack *stack.Stack
+
+	mu struct {
+		sync.Mutex
+
+		listeners map[uint16]*tcpListener
+	}
+}
+
+func New(config *config.C) (*Service, error) {
+	logger := logrus.New()
+	logger.Out = os.Stdout
+
+	control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
+	if err != nil {
+		return nil, err
+	}
+	control.Start()
+
+	ctx := control.Context()
+	eg, ctx := errgroup.WithContext(ctx)
+	s := Service{
+		eg:      eg,
+		control: control,
+	}
+	s.mu.listeners = map[uint16]*tcpListener{}
+
+	device, ok := control.Device().(*overlay.UserDevice)
+	if !ok {
+		return nil, errors.New("must be using user device")
+	}
+
+	s.ipstack = stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
+	})
+	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+	tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+	if tcpipErr != nil {
+		return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+	}
+	linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
+	if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
+		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
+	}
+	ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
+	s.ipstack.SetRouteTable([]tcpip.Route{
+		{
+			Destination: ipv4Subnet,
+			NIC:         nicID,
+		},
+	})
+
+	ipNet := device.Cidr()
+	pa := tcpip.ProtocolAddress{
+		AddressWithPrefix: tcpip.Address(ipNet.IP).WithPrefix(),
+		Protocol:          ipv4.ProtocolNumber,
+	}
+	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
+		PEB:        stack.CanBePrimaryEndpoint, // zero value default
+		ConfigType: stack.AddressConfigStatic,  // zero value default
+	}); err != nil {
+		return nil, fmt.Errorf("error creating IP: %s", err)
+	}
+
+	const tcpReceiveBufferSize = 0
+	const maxInFlightConnectionAttempts = 1024
+	tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
+	s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
+
+	reader, writer := device.Pipe()
+
+	go func() {
+		<-ctx.Done()
+		reader.Close()
+		writer.Close()
+	}()
+
+	// create Goroutines to forward packets between Nebula and Gvisor
+	eg.Go(func() error {
+		buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
+		for {
+			// this will read exactly one packet
+			n, err := reader.Read(buf)
+			if err != nil {
+				return err
+			}
+			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+				Payload: bufferv2.MakeWithData(bytes.Clone(buf[:n])),
+			})
+			linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
+
+			if err := ctx.Err(); err != nil {
+				return err
+			}
+		}
+	})
+	eg.Go(func() error {
+		for {
+			packet := linkEP.ReadContext(ctx)
+			if packet.IsNil() {
+				if err := ctx.Err(); err != nil {
+					return err
+				}
+				continue
+			}
+			bufView := packet.ToView()
+			if _, err := bufView.WriteTo(writer); err != nil {
+				return err
+			}
+			bufView.Release()
+		}
+	})
+
+	return &s, nil
+}
+
+// DialContext dials the provided address. Currently only TCP is supported.
+func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	if network != "tcp" && network != "tcp4" {
+		return nil, errors.New("only tcp is supported")
+	}
+
+	addr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+
+	fullAddr := tcpip.FullAddress{
+		NIC:  nicID,
+		Addr: tcpip.Address(addr.IP),
+		Port: uint16(addr.Port),
+	}
+
+	return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
+}
+
+// Listen listens on the provided address. Currently only TCP with wildcard
+// addresses are supported.
+func (s *Service) Listen(network, address string) (net.Listener, error) {
+	if network != "tcp" && network != "tcp4" {
+		return nil, errors.New("only tcp is supported")
+	}
+	addr, err := net.ResolveTCPAddr(network, address)
+	if err != nil {
+		return nil, err
+	}
+	if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
+		return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
+	}
+	if addr.Port == 0 {
+		return nil, errors.New("specific port required, got 0")
+	}
+	if addr.Port < 0 || addr.Port >= math.MaxUint16 {
+		return nil, fmt.Errorf("invalid port %d", addr.Port)
+	}
+	port := uint16(addr.Port)
+
+	l := &tcpListener{
+		port:   port,
+		s:      s,
+		addr:   addr,
+		accept: make(chan net.Conn),
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if _, ok := s.mu.listeners[port]; ok {
+		return nil, fmt.Errorf("already listening on port %d", port)
+	}
+	s.mu.listeners[port] = l
+
+	return l, nil
+}
+
+func (s *Service) Wait() error {
+	return s.eg.Wait()
+}
+
+func (s *Service) Close() error {
+	s.control.Stop()
+	return nil
+}
+
+func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
+	endpointID := r.ID()
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	l, ok := s.mu.listeners[endpointID.LocalPort]
+	if !ok {
+		r.Complete(true)
+		return
+	}
+
+	var wq waiter.Queue
+	ep, err := r.CreateEndpoint(&wq)
+	if err != nil {
+		log.Printf("got error creating endpoint %q", err)
+		r.Complete(true)
+		return
+	}
+	r.Complete(false)
+	ep.SocketOptions().SetKeepAlive(true)
+
+	conn := gonet.NewTCPConn(&wq, ep)
+	l.accept <- conn
+}

+ 165 - 0
service/service_test.go

@@ -0,0 +1,165 @@
+package service
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"net"
+	"testing"
+	"time"
+
+	"dario.cat/mergo"
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/e2e"
+	"golang.org/x/sync/errgroup"
+	"gopkg.in/yaml.v2"
+)
+
+type m map[string]interface{}
+
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
+
+	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
+	copy(vpnIpNet.IP, udpIp)
+
+	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	caB, err := caCrt.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	mc := m{
+		"pki": m{
+			"ca":   string(caB),
+			"cert": string(myPEM),
+			"key":  string(myPrivKey),
+		},
+		//"tun": m{"disabled": true},
+		"firewall": m{
+			"outbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+			"inbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+		},
+		"timers": m{
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
+		},
+		"handshakes": m{
+			"try_interval": "200ms",
+		},
+	}
+
+	if overrides != nil {
+		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = overrides
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	var c config.C
+	if err := c.LoadString(string(cb)); err != nil {
+		panic(err)
+	}
+
+	s, err := New(&c)
+	if err != nil {
+		panic(err)
+	}
+	return s
+}
+
+func TestService(t *testing.T) {
+	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+		"static_host_map": m{},
+		"lighthouse": m{
+			"am_lighthouse": true,
+		},
+		"listen": m{
+			"host": "0.0.0.0",
+			"port": 4243,
+		},
+	})
+	b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+		"static_host_map": m{
+			"10.0.0.1": []string{"localhost:4243"},
+		},
+		"lighthouse": m{
+			"hosts":    []string{"10.0.0.1"},
+			"interval": 1,
+		},
+	})
+
+	ln, err := a.Listen("tcp", ":1234")
+	if err != nil {
+		t.Fatal(err)
+	}
+	var eg errgroup.Group
+	eg.Go(func() error {
+		conn, err := ln.Accept()
+		if err != nil {
+			return err
+		}
+		defer conn.Close()
+
+		t.Log("accepted connection")
+
+		if _, err := conn.Write([]byte("server msg")); err != nil {
+			return err
+		}
+
+		t.Log("server: wrote message")
+
+		data := make([]byte, 100)
+		n, err := conn.Read(data)
+		if err != nil {
+			return err
+		}
+		data = data[:n]
+		if !bytes.Equal(data, []byte("client msg")) {
+			return errors.New("got invalid message from client")
+		}
+		t.Log("server: read message")
+		return conn.Close()
+	})
+
+	c, err := b.DialContext(context.Background(), "tcp", "10.0.0.1:1234")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, err := c.Write([]byte("client msg")); err != nil {
+		t.Fatal(err)
+	}
+
+	data := make([]byte, 100)
+	n, err := c.Read(data)
+	if err != nil {
+		t.Fatal(err)
+	}
+	data = data[:n]
+	if !bytes.Equal(data, []byte("server msg")) {
+		t.Fatal("got invalid message from client")
+	}
+
+	if err := c.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := eg.Wait(); err != nil {
+		t.Fatal(err)
+	}
+}

+ 2 - 3
ssh.go

@@ -6,7 +6,6 @@ import (
 	"errors"
 	"flag"
 	"fmt"
-	"io/ioutil"
 	"net"
 	"os"
 	"reflect"
@@ -96,7 +95,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 		return nil, fmt.Errorf("sshd.host_key must be provided")
 	}
 
-	hostKeyBytes, err := ioutil.ReadFile(hostKeyFile)
+	hostKeyBytes, err := os.ReadFile(hostKeyFile)
 	if err != nil {
 		return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
 	}
@@ -519,7 +518,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 	}
 
 	var cm *CacheMap
-	rl := ifce.lightHouse.Query(vpnIp, ifce)
+	rl := ifce.lightHouse.Query(vpnIp)
 	if rl != nil {
 		cm = rl.CopyCache()
 	}

+ 2 - 2
test/logger.go

@@ -1,7 +1,7 @@
 package test
 
 import (
-	"io/ioutil"
+	"io"
 	"os"
 
 	"github.com/sirupsen/logrus"
@@ -12,7 +12,7 @@ func NewLogger() *logrus.Logger {
 
 	v := os.Getenv("TEST_LOGS")
 	if v == "" {
-		l.SetOutput(ioutil.Discard)
+		l.SetOutput(io.Discard)
 		return l
 	}