Browse Source

Merge remote-tracking branch 'origin/master' into multiport

Wade Simmons 2 weeks ago
parent
commit
510a8912a9
73 changed files with 3978 additions and 1145 deletions
  1. 3 3
      .github/workflows/gofmt.yml
  2. 16 16
      .github/workflows/release.yml
  3. 3 3
      .github/workflows/smoke-extra.yml
  4. 3 3
      .github/workflows/smoke.yml
  5. 18 18
      .github/workflows/test.yml
  6. 80 3
      CHANGELOG.md
  7. 32 77
      bits.go
  8. 86 23
      bits_test.go
  9. 3 8
      calculated_remote.go
  10. 4 2
      cert/cert.go
  11. 8 2
      cert/cert_v1.go
  12. 54 0
      cert/cert_v1_test.go
  13. 8 2
      cert/cert_v2.go
  14. 53 0
      cert/cert_v2_test.go
  15. 1 0
      cert/errors.go
  16. 42 12
      cert/pem.go
  17. 17 2
      cert/pem_test.go
  18. 3 9
      cert/sign.go
  19. 27 0
      cert_test/cert.go
  20. 17 14
      cmd/nebula-cert/ca.go
  21. 11 0
      cmd/nebula-cert/ca_test.go
  22. 18 0
      cmd/nebula-cert/main.go
  23. 35 29
      cmd/nebula-cert/sign.go
  24. 22 2
      cmd/nebula-cert/sign_test.go
  25. 13 0
      cmd/nebula-service/main.go
  26. 13 0
      cmd/nebula/main.go
  27. 1 1
      config/config.go
  28. 1 1
      config/config_test.go
  29. 60 24
      connection_manager.go
  30. 5 1
      connection_manager_test.go
  31. 1 6
      connection_state.go
  32. 4 0
      control_tester.go
  33. 209 2
      e2e/handshakes_test.go
  34. 158 13
      e2e/helpers_test.go
  35. 310 0
      e2e/tunnels_test.go
  36. 3 2
      examples/config.yml
  37. 125 71
      firewall.go
  38. 586 54
      firewall_test.go
  39. 19 19
      go.mod
  40. 38 33
      go.sum
  41. 146 154
      handshake_ix.go
  42. 14 13
      handshake_manager.go
  43. 56 28
      hostmap.go
  44. 6 5
      inside.go
  45. 8 1
      interface.go
  46. 133 115
      lighthouse.go
  47. 121 1
      lighthouse_test.go
  48. 24 2
      main.go
  49. 51 47
      outside.go
  50. 1 0
      overlay/device.go
  51. 50 0
      overlay/tun.go
  52. 4 0
      overlay/tun_android.go
  53. 4 12
      overlay/tun_darwin.go
  54. 4 0
      overlay/tun_disabled.go
  55. 383 68
      overlay/tun_freebsd.go
  56. 4 0
      overlay/tun_ios.go
  57. 52 28
      overlay/tun_linux.go
  58. 372 63
      overlay/tun_netbsd.go
  59. 315 97
      overlay/tun_openbsd.go
  60. 4 0
      overlay/tun_tester.go
  61. 4 0
      overlay/tun_windows.go
  62. 4 0
      overlay/user.go
  63. 1 0
      pkclient/pkclient_cgo.go
  64. 54 42
      pki.go
  65. 21 10
      remote_list.go
  66. 1 1
      service/service_test.go
  67. 4 0
      test/tun.go
  68. 4 0
      udp/conn.go
  69. 7 3
      udp/udp_darwin.go
  70. 4 0
      udp/udp_generic.go
  71. 4 0
      udp/udp_linux.go
  72. 4 0
      udp/udp_rio_windows.go
  73. 4 0
      udp/udp_tester.go

+ 3 - 3
.github/workflows/gofmt.yml

@@ -14,11 +14,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version: '1.24'
+        go-version: '1.25'
         check-latest: true
 
     - name: Install goimports

+ 16 - 16
.github/workflows/release.yml

@@ -10,11 +10,11 @@ jobs:
     name: Build Linux/BSD All
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
 
-      - uses: actions/setup-go@v5
+      - uses: actions/setup-go@v6
         with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
 
       - name: Build
@@ -24,7 +24,7 @@ jobs:
           mv build/*.tar.gz release
 
       - name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: linux-latest
           path: release
@@ -33,11 +33,11 @@ jobs:
     name: Build Windows
     runs-on: windows-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
 
-      - uses: actions/setup-go@v5
+      - uses: actions/setup-go@v6
         with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
 
       - name: Build
@@ -55,7 +55,7 @@ jobs:
           mv dist\windows\wintun build\dist\windows\
 
       - name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: windows-latest
           path: build
@@ -66,11 +66,11 @@ jobs:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
     runs-on: macos-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
 
-      - uses: actions/setup-go@v5
+      - uses: actions/setup-go@v6
         with:
-          go-version: '1.24'
+          go-version: '1.25'
           check-latest: true
 
       - name: Import certificates
@@ -104,7 +104,7 @@ jobs:
           fi
 
       - name: Upload artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: darwin-latest
           path: ./release/*
@@ -124,11 +124,11 @@ jobs:
       # be overwritten
       - name: Checkout code
         if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/checkout@v4
+        uses: actions/checkout@v5
 
       - name: Download artifacts
         if: ${{ env.HAS_DOCKER_CREDS == 'true' }}
-        uses: actions/download-artifact@v4
+        uses: actions/download-artifact@v6
         with:
           name: linux-latest
           path: artifacts
@@ -160,10 +160,10 @@ jobs:
     needs: [build-linux, build-darwin, build-windows]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v5
 
       - name: Download artifacts
-        uses: actions/download-artifact@v4
+        uses: actions/download-artifact@v6
         with:
           path: artifacts
 

+ 3 - 3
.github/workflows/smoke-extra.yml

@@ -20,11 +20,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version-file: 'go.mod'
+        go-version: '1.25'
         check-latest: true
 
     - name: add hashicorp source

+ 3 - 3
.github/workflows/smoke.yml

@@ -18,11 +18,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version: '1.24'
+        go-version: '1.25'
         check-latest: true
 
     - name: build

+ 18 - 18
.github/workflows/test.yml

@@ -18,11 +18,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version: '1.24'
+        go-version: '1.25'
         check-latest: true
 
     - name: Build
@@ -32,9 +32,9 @@ jobs:
       run: make vet
 
     - name: golangci-lint
-      uses: golangci/golangci-lint-action@v8
+      uses: golangci/golangci-lint-action@v9
       with:
-        version: v2.1
+        version: v2.5
 
     - name: Test
       run: make test
@@ -45,7 +45,7 @@ jobs:
     - name: Build test mobile
       run: make build-test-mobile
 
-    - uses: actions/upload-artifact@v4
+    - uses: actions/upload-artifact@v5
       with:
         name: e2e packet flow linux-latest
         path: e2e/mermaid/linux-latest
@@ -56,11 +56,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version: '1.24'
+        go-version: '1.25'
         check-latest: true
 
     - name: Build
@@ -77,11 +77,11 @@ jobs:
     runs-on: ubuntu-latest
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version: '1.22'
+        go-version: '1.25'
         check-latest: true
 
     - name: Build
@@ -98,11 +98,11 @@ jobs:
         os: [windows-latest, macos-latest]
     steps:
 
-    - uses: actions/checkout@v4
+    - uses: actions/checkout@v5
 
-    - uses: actions/setup-go@v5
+    - uses: actions/setup-go@v6
       with:
-        go-version: '1.24'
+        go-version: '1.25'
         check-latest: true
 
     - name: Build nebula
@@ -115,9 +115,9 @@ jobs:
       run: make vet
 
     - name: golangci-lint
-      uses: golangci/golangci-lint-action@v8
+      uses: golangci/golangci-lint-action@v9
       with:
-        version: v2.1
+        version: v2.5
 
     - name: Test
       run: make test
@@ -125,7 +125,7 @@ jobs:
     - name: End 2 end
       run: make e2evv
 
-    - uses: actions/upload-artifact@v4
+    - uses: actions/upload-artifact@v5
       with:
         name: e2e packet flow ${{ matrix.os }}
         path: e2e/mermaid/${{ matrix.os }}

+ 80 - 3
CHANGELOG.md

@@ -7,12 +7,85 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.10.0] - 2025-12-04
+
+See the [v1.10.0](https://github.com/slackhq/nebula/milestone/16?closed=1) milestone for a complete list of changes.
+
+### Added
+
+- Support for ipv6 and multiple ipv4/6 addresses in the overlay.
+  A new v2 ASN.1 based certificate format.
+  Certificates now have a unified interface for external implementations.
+  (#1212, #1216, #1345, #1359, #1381, #1419, #1464, #1466, #1451, #1476, #1467, #1481, #1399, #1488, #1492, #1495, #1468, #1521, #1535, #1538)
+- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
+- Add ECMP support for `unsafe_routes`. (#1332)
+- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153, #1482)
+
 ### Changed
 
-- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
+- **NOTE**: `default_local_cidr_any` now defaults to false, meaning that any firewall rule
   intended to target an `unsafe_routes` entry must explicitly declare it via the
   `local_cidr` field. This is almost always the intended behavior. This flag is
-  deprecated and will be removed in a future release.
+  deprecated and will be removed in a future release. (#1373)
+- Improve logging when a relay is in use on an inbound packet. (#1533)
+- Avoid fatal errors if `rountines` is > 1 on systems that don't support more than 1 routine. (#1531)
+- Log a warning if a firewall rule contains an `any` that negates a more restrictive filter. (#1513)
+- Accept encrypted CA passphrase from an environment variable. (#1421)
+- Allow handshaking with any trusted remote. (#1509)
+- Log only the count of blocklisted certificate fingerprints instead of the entire list. (#1525)
+- Don't fatal when the ssh server is unable to be configured successfully. (#1520)
+- Update to build against go v1.25. (#1483)
+- Allow projects using `nebula` as a library with userspace networking to configure the `logger` and build version. (#1239) 
+- Upgrade to `yaml.v3`. (#1148, #1371, #1438, #1478)
+
+### Fixed
+
+- Fix a potential bug with udp ipv4 only on darwin. (#1532)
+- Improve lost packet statistics. (#1441, #1537)
+- Honor `remote_allow_list` in hole punch response. (#1186)
+- Fix a panic when `tun.use_system_route_table` is `true` and a route lacks a destination. (#1437) 
+- Fix an issue when `tun.use_system_route_table: true` could result in heavy CPU utilization when many thousands of routes
+  are present. (#1326) 
+- Fix tests for 32 bit machines. (#1394)
+- Fix a possible 32bit integer underflow in config handling. (#1353)
+- Fix moving a udp address from one vpn address to another in the `static_host_map`
+  which could cause rapid re-handshaking with an incorrect remote. (#1259)
+- Improve smoke tests in environments where the docker network is not the default. (#1347)
+
+## [1.9.7] - 2025-10-10
+
+### Security
+
+- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
+  certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
+
+### Changed
+
+- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
+- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1453)
+
+## [1.9.6] - 2025-7-15
+
+### Added
+
+- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
+
+### Fixed
+
+- Fix Darwin freeze due to presence of some Network Extensions (#1426)
+- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
+- Fix Windows freeze due to ICMP error handling (#1412)
+- Fix relay migration panic (#1403)
+
+## [1.9.5] - 2024-12-05
+
+### Added
+
+- Gracefully ignore v2 certificates. (#1282)
+
+### Fixed
+
+- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
 
 ## [1.9.4] - 2024-09-09
 
@@ -671,7 +744,11 @@ created.)
 
 - Initial public release.
 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
+[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
+[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
+[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
+[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
 [1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2

+ 32 - 77
bits.go

@@ -9,14 +9,13 @@ type Bits struct {
 	length             uint64
 	current            uint64
 	bits               []bool
-	firstSeen          bool
 	lostCounter        metrics.Counter
 	dupeCounter        metrics.Counter
 	outOfWindowCounter metrics.Counter
 }
 
 func NewBits(bits uint64) *Bits {
-	return &Bits{
+	b := &Bits{
 		length:             bits,
 		bits:               make([]bool, bits, bits),
 		current:            0,
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
 		dupeCounter:        metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
 		outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
 	}
+
+	// There is no counter value 0, mark it to avoid counting a lost packet later.
+	b.bits[0] = true
+	b.current = 0
+	return b
 }
 
-func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
+func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
 	// If i is the next number, return true.
-	if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
+	if i > b.current {
 		return true
 	}
 
-	// If i is within the window, check if it's been set already. The first window will fail this check
-	if i > b.current-b.length {
-		return !b.bits[i%b.length]
-	}
-
-	// If i is within the first window
-	if i < b.length {
+	// If i is within the window, check if it's been set already.
+	if i > b.current-b.length || i < b.length && b.current < b.length {
 		return !b.bits[i%b.length]
 	}
 
 	// Not within the window
-	l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
+	if l.Level >= logrus.DebugLevel {
+		l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
+	}
 	return false
 }
 
 func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 	// If i is the next number, return true and update current.
 	if i == b.current+1 {
-		// Report missed packets, we can only understand what was missed after the first window has been gone through
-		if i > b.length && b.bits[i%b.length] == false {
+		// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
+		// The very first window can only be tracked as lost once we are on the 2nd window or greater
+		if b.bits[i%b.length] == false && i > b.length {
 			b.lostCounter.Inc(1)
 		}
 		b.bits[i%b.length] = true
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 		return true
 	}
 
-	// If i packet is greater than current but less than the maximum length of our bitmap,
-	// flip everything in between to false and move ahead.
-	if i > b.current && i < b.current+b.length {
-		// In between current and i need to be zero'd to allow those packets to come in later
-		for n := b.current + 1; n < i; n++ {
+	// If i is a jump, adjust the window, record lost, update current, and return true
+	if i > b.current {
+		lost := int64(0)
+		// Zero out the bits between the current and the new counter value, limited by the window size,
+		// since the window is shifting
+		for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
+			if b.bits[n%b.length] == false && n > b.length {
+				lost++
+			}
 			b.bits[n%b.length] = false
 		}
 
-		b.bits[i%b.length] = true
-		b.current = i
-		//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
-		return true
-	}
-
-	// If i is greater than the delta between current and the total length of our bitmap,
-	// just flip everything in the map and move ahead.
-	if i >= b.current+b.length {
-		// The current window loss will be accounted for later, only record the jump as loss up until then
-		lost := maxInt64(0, int64(i-b.current-b.length))
-		//TODO: explain this
-		if b.current == 0 {
-			lost++
-		}
-
-		for n := range b.bits {
-			// Don't want to count the first window as a loss
-			//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
-			//if b.bits[n] == false {
-			//	lost++
-			//}
-			b.bits[n] = false
-		}
-
+		// Only record any skipped packets as a result of the window moving further than the window length
+		// Any loss within the new window will be accounted for in future calls
+		lost += max(0, int64(i-b.current-b.length))
 		b.lostCounter.Inc(lost)
 
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
-				Debug("Receive window")
-		}
 		b.bits[i%b.length] = true
 		b.current = i
 		return true
 	}
 
-	// Allow for the 0 packet to come in within the first window
-	if i == 0 && b.firstSeen == false && b.current < b.length {
-		b.firstSeen = true
-		b.bits[i%b.length] = true
-		return true
-	}
-
-	// If i is within the window of current minus length (the total pat window size),
-	// allow it and flip to true but to NOT change current. We also have to account for the first window
-	if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
-		if b.current == i {
+	// If i is within the current window but below the current counter,
+	// Check to see if it's a duplicate
+	if i > b.current-b.length || i < b.length && b.current < b.length {
+		if b.current == i || b.bits[i%b.length] == true {
 			if l.Level >= logrus.DebugLevel {
 				l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
 					Debug("Receive window")
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 			return false
 		}
 
-		if b.bits[i%b.length] == true {
-			if l.Level >= logrus.DebugLevel {
-				l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
-					Debug("Receive window")
-			}
-			b.dupeCounter.Inc(1)
-			return false
-		}
-
 		b.bits[i%b.length] = true
 		return true
-
 	}
 
 	// In all other cases, fail and don't change current.
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 	}
 	return false
 }
-
-func maxInt64(a, b int64) int64 {
-	if a > b {
-		return a
-	}
-
-	return b
-}

+ 86 - 23
bits_test.go

@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
 	assert.Len(t, b.bits, 10)
 
 	// This is initialized to zero - receive one. This should work.
-
 	assert.True(t, b.Check(l, 1))
-	u := b.Update(l, 1)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 1))
 	assert.EqualValues(t, 1, b.current)
-	g := []bool{false, true, false, false, false, false, false, false, false, false}
+	g := []bool{true, true, false, false, false, false, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Receive two
 	assert.True(t, b.Check(l, 2))
-	u = b.Update(l, 2)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 2))
 	assert.EqualValues(t, 2, b.current)
-	g = []bool{false, true, true, false, false, false, false, false, false, false}
+	g = []bool{true, true, true, false, false, false, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Receive two again - it will fail
 	assert.False(t, b.Check(l, 2))
-	u = b.Update(l, 2)
-	assert.False(t, u)
+	assert.False(t, b.Update(l, 2))
 	assert.EqualValues(t, 2, b.current)
 
 	// Jump ahead to 15, which should clear everything and set the 6th element
 	assert.True(t, b.Check(l, 15))
-	u = b.Update(l, 15)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 15))
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, false, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Mark 14, which is allowed because it is in the window
 	assert.True(t, b.Check(l, 14))
-	u = b.Update(l, 14)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 14))
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, true, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Mark 5, which is not allowed because it is not in the window
 	assert.False(t, b.Check(l, 5))
-	u = b.Update(l, 5)
-	assert.False(t, u)
+	assert.False(t, b.Update(l, 5))
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, true, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
 
 	// Walk through a few windows in order
 	b = NewBits(10)
-	for i := uint64(0); i <= 100; i++ {
+	for i := uint64(1); i <= 100; i++ {
 		assert.True(t, b.Check(l, i), "Error while checking %v", i)
 		assert.True(t, b.Update(l, i), "Error while updating %v", i)
 	}
+
+	assert.False(t, b.Check(l, 1), "Out of window check")
+}
+
+func TestBitsLargeJumps(t *testing.T) {
+	l := test.NewLogger()
+	b := NewBits(10)
+	b.lostCounter.Clear()
+
+	b = NewBits(10)
+	b.lostCounter.Clear()
+	assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
+	assert.Equal(t, int64(45), b.lostCounter.Count())
+
+	assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
+	assert.Equal(t, int64(89), b.lostCounter.Count())
+
+	assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
+	assert.Equal(t, int64(188), b.lostCounter.Count())
 }
 
 func TestBitsDupeCounter(t *testing.T) {
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
 	assert.False(t, b.Update(l, 0))
 	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
 
-	//tODO: make sure lostcounter doesn't increase in orderly increment
-	assert.Equal(t, int64(20), b.lostCounter.Count())
+	assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
 	assert.Equal(t, int64(0), b.dupeCounter.Count())
 	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
 }
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	//assert.True(t, b.Update(0))
-	assert.True(t, b.Update(l, 0))
 	assert.True(t, b.Update(l, 20))
 	assert.True(t, b.Update(l, 21))
 	assert.True(t, b.Update(l, 22))
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
 	assert.True(t, b.Update(l, 27))
 	assert.True(t, b.Update(l, 28))
 	assert.True(t, b.Update(l, 29))
-	assert.Equal(t, int64(20), b.lostCounter.Count())
+	assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
 	assert.Equal(t, int64(0), b.dupeCounter.Count())
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	assert.True(t, b.Update(l, 0))
-	assert.Equal(t, int64(0), b.lostCounter.Count())
 	assert.True(t, b.Update(l, 9))
 	assert.Equal(t, int64(0), b.lostCounter.Count())
 	// 10 will set 0 index, 0 was already set, no lost packets
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 }
 
+func TestBitsLostCounterIssue1(t *testing.T) {
+	l := test.NewLogger()
+	b := NewBits(10)
+	b.lostCounter.Clear()
+	b.dupeCounter.Clear()
+	b.outOfWindowCounter.Clear()
+
+	assert.True(t, b.Update(l, 4))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 1))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 9))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 2))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 3))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 5))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 6))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 7))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	// assert.True(t, b.Update(l, 8))
+	assert.True(t, b.Update(l, 10))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 11))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+
+	assert.True(t, b.Update(l, 14))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
+	assert.True(t, b.Update(l, 19))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 12))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 13))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 15))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 16))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 17))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 18))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 20))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 21))
+
+	// We missed packet 8 above
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.Equal(t, int64(0), b.dupeCounter.Count())
+	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
+}
+
 func BenchmarkBits(b *testing.B) {
 	z := NewBits(10)
 	for n := 0; n < b.N; n++ {

+ 3 - 8
calculated_remote.go

@@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 
 	calculatedRemotes := new(bart.Table[[]*calculatedRemote])
 
-	rawMap, ok := value.(map[any]any)
+	rawMap, ok := value.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
 	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
+	for rawCIDR, rawValue := range rawMap {
 		cidr, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
@@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat
 }
 
 func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
-	rawMap, ok := raw.(map[any]any)
+	rawMap, ok := raw.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
 	}

+ 4 - 2
cert/cert.go

@@ -58,6 +58,9 @@ type Certificate interface {
 	// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
 	PublicKey() []byte
 
+	// MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM
+	MarshalPublicKeyPEM() []byte
+
 	// Curve identifies which curve was used for the PublicKey and Signature.
 	Curve() Curve
 
@@ -135,8 +138,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific
 	case Version2:
 		c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
 	default:
-		//TODO: CERT-V2 make a static var
-		return nil, fmt.Errorf("unknown certificate version %d", v)
+		return nil, ErrUnknownVersion
 	}
 
 	if err != nil {

+ 8 - 2
cert/cert_v1.go

@@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte {
 	return c.details.publicKey
 }
 
+func (c *certificateV1) MarshalPublicKeyPEM() []byte {
+	return marshalCertPublicKeyToPEM(c)
+}
+
 func (c *certificateV1) Signature() []byte {
 	return c.signature
 }
@@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool {
 	case Curve_CURVE25519:
 		return ed25519.Verify(key, b, c.signature)
 	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+		if err != nil {
+			return false
+		}
 		hashed := sha256.Sum256(b)
 		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
 	default:

+ 54 - 0
cert/cert_v1_test.go

@@ -1,6 +1,7 @@
 package cert
 
 import (
+	"crypto/ed25519"
 	"fmt"
 	"net/netip"
 	"testing"
@@ -13,6 +14,7 @@ import (
 )
 
 func TestCertificateV1_Marshal(t *testing.T) {
+	t.Parallel()
 	before := time.Now().Add(time.Second * -60).Round(time.Second)
 	after := time.Now().Add(time.Second * 60).Round(time.Second)
 	pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -60,6 +62,58 @@ func TestCertificateV1_Marshal(t *testing.T) {
 	assert.Equal(t, nc.Groups(), nc2.Groups())
 }
 
+func TestCertificateV1_PublicKeyPem(t *testing.T) {
+	t.Parallel()
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name:           "testing",
+			networks:       []netip.Prefix{},
+			unsafeNetworks: []netip.Prefix{},
+			groups:         []string{"test-group1", "test-group2", "test-group3"},
+			notBefore:      before,
+			notAfter:       after,
+			publicKey:      pubKey,
+			isCA:           false,
+			issuer:         "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	assert.Equal(t, Version1, nc.Version())
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.False(t, nc.IsCA())
+
+	nc.details.isCA = true
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.True(t, nc.IsCA())
+
+	pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
+	require.NoError(t, err)
+	nc.details.curve = Curve_P256
+	nc.details.publicKey = pubP256Key
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.True(t, nc.IsCA())
+
+	nc.details.isCA = false
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.False(t, nc.IsCA())
+}
+
 func TestCertificateV1_Expired(t *testing.T) {
 	nc := certificateV1{
 		details: detailsV1{

+ 8 - 2
cert/cert_v2.go

@@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte {
 	return c.publicKey
 }
 
+func (c *certificateV2) MarshalPublicKeyPEM() []byte {
+	return marshalCertPublicKeyToPEM(c)
+}
+
 func (c *certificateV2) Signature() []byte {
 	return c.signature
 }
@@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool {
 	case Curve_CURVE25519:
 		return ed25519.Verify(key, b, c.signature)
 	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key)
+		if err != nil {
+			return false
+		}
 		hashed := sha256.Sum256(b)
 		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
 	default:

+ 53 - 0
cert/cert_v2_test.go

@@ -15,6 +15,7 @@ import (
 )
 
 func TestCertificateV2_Marshal(t *testing.T) {
+	t.Parallel()
 	before := time.Now().Add(time.Second * -60).Round(time.Second)
 	after := time.Now().Add(time.Second * 60).Round(time.Second)
 	pubKey := []byte("1234567890abcedfghij1234567890ab")
@@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
 	assert.Equal(t, nc.Groups(), nc2.Groups())
 }
 
+func TestCertificateV2_PublicKeyPem(t *testing.T) {
+	t.Parallel()
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name:           "testing",
+			networks:       []netip.Prefix{},
+			unsafeNetworks: []netip.Prefix{},
+			groups:         []string{"test-group1", "test-group2", "test-group3"},
+			notBefore:      before,
+			notAfter:       after,
+			isCA:           false,
+			issuer:         "1234567890abcedfghij1234567890ab",
+		},
+		publicKey: pubKey,
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	assert.Equal(t, Version2, nc.Version())
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.False(t, nc.IsCA())
+
+	nc.details.isCA = true
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n"
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem)
+	assert.True(t, nc.IsCA())
+
+	pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem)
+	require.NoError(t, err)
+	nc.curve = Curve_P256
+	nc.publicKey = pubP256Key
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.True(t, nc.IsCA())
+
+	nc.details.isCA = false
+	assert.Equal(t, Curve_P256, nc.Curve())
+	assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem))
+	assert.False(t, nc.IsCA())
+}
+
 func TestCertificateV2_Expired(t *testing.T) {
 	nc := certificateV2{
 		details: detailsV2{

+ 1 - 0
cert/errors.go

@@ -20,6 +20,7 @@ var (
 	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
 	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
 	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
+	ErrUnknownVersion             = errors.New("certificate version unrecognized")
 
 	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
 	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")

+ 42 - 12
cert/pem.go

@@ -7,19 +7,26 @@ import (
 	"golang.org/x/crypto/ed25519"
 )
 
-const (
-	CertificateBanner                = "NEBULA CERTIFICATE"
-	CertificateV2Banner              = "NEBULA CERTIFICATE V2"
-	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
-	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
-	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
-	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
-	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
-
-	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
-	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
+const ( //cert banners
+	CertificateBanner   = "NEBULA CERTIFICATE"
+	CertificateV2Banner = "NEBULA CERTIFICATE V2"
+)
+
+const ( //key-agreement-key banners
+	X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY"
+	X25519PublicKeyBanner  = "NEBULA X25519 PUBLIC KEY"
+	P256PrivateKeyBanner   = "NEBULA P256 PRIVATE KEY"
+	P256PublicKeyBanner    = "NEBULA P256 PUBLIC KEY"
+)
+
+/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */
+const ( //signing key banners
 	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
 	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+	ECDSAP256PublicKeyBanner           = "NEBULA ECDSA P256 PUBLIC KEY"
+	EncryptedEd25519PrivateKeyBanner   = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
+	Ed25519PrivateKeyBanner            = "NEBULA ED25519 PRIVATE KEY"
+	Ed25519PublicKeyBanner             = "NEBULA ED25519 PUBLIC KEY"
 )
 
 // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
@@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
 
 }
 
+func marshalCertPublicKeyToPEM(c Certificate) []byte {
+	if c.IsCA() {
+		return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
+	} else {
+		return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey())
+	}
+}
+
+// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH.
+// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
 func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
 	switch curve {
 	case Curve_CURVE25519:
@@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
 	}
 }
 
+// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing.
+// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes!
+func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
 func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
 	k, r := pem.Decode(b)
 	if k == nil {
@@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
 	case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
 		expectedLen = 32
 		curve = Curve_CURVE25519
-	case P256PublicKeyBanner:
+	case P256PublicKeyBanner, ECDSAP256PublicKeyBanner:
 		// Uncompressed
 		expectedLen = 65
 		curve = Curve_P256

+ 17 - 2
cert/pem_test.go

@@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 }
 
 func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
+	t.Parallel()
 	pubKey := []byte(`# A good key
 -----BEGIN NEBULA ED25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 }
 
 func TestUnmarshalX25519PublicKey(t *testing.T) {
+	t.Parallel()
 	pubKey := []byte(`# A good key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
@@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 AAAAAAAAAAAAAAAAAAAAAAA=
 -----END NEBULA P256 PUBLIC KEY-----
+`)
+	oldPubP256Key := []byte(`# A good key
+-----BEGIN NEBULA ECDSA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ECDSA P256 PUBLIC KEY-----
 `)
 	shortKey := []byte(`# A short key
 -----BEGIN NEBULA X25519 PUBLIC KEY-----
@@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
 -END NEBULA X25519 PUBLIC KEY-----`)
 
-	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
+	keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)
 
 	// Success test case
 	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
 	assert.Len(t, k, 32)
 	require.NoError(t, err)
-	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem))
 	assert.Equal(t, Curve_CURVE25519, curve)
 
+	// Success test case
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Len(t, k, 65)
+	require.NoError(t, err)
+	assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+
 	// Success test case
 	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
 	assert.Len(t, k, 65)

+ 3 - 9
cert/sign.go

@@ -7,7 +7,6 @@ import (
 	"crypto/rand"
 	"crypto/sha256"
 	"fmt"
-	"math/big"
 	"net/netip"
 	"time"
 )
@@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
 		}
 		return t.SignWith(signer, curve, sp)
 	case Curve_P256:
-		pk := &ecdsa.PrivateKey{
-			PublicKey: ecdsa.PublicKey{
-				Curve: elliptic.P256(),
-			},
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-			D: new(big.Int).SetBytes(key),
+		pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key)
+		if err != nil {
+			return nil, err
 		}
-		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
 		sp := func(certBytes []byte) ([]byte, error) {
 			// We need to hash first for ECDSA
 			// - https://pkg.go.dev/crypto/ecdsa#SignASN1

+ 27 - 0
cert_test/cert.go

@@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by
 	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
 }
 
+func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) {
+	nc := &cert.TBSCertificate{
+		Version:        v,
+		Curve:          c.Curve(),
+		Name:           c.Name(),
+		Networks:       c.Networks(),
+		UnsafeNetworks: c.UnsafeNetworks(),
+		Groups:         c.Groups(),
+		NotBefore:      time.Unix(c.NotBefore().Unix(), 0),
+		NotAfter:       time.Unix(c.NotAfter().Unix(), 0),
+		PublicKey:      c.PublicKey(),
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pem
+}
+
 func X25519Keypair() ([]byte, []byte) {
 	privkey := make([]byte, 32)
 	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {

+ 17 - 14
cmd/nebula-cert/ca.go

@@ -173,23 +173,26 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 
 	var passphrase []byte
 	if !isP11 && *cf.encryption {
-		for i := 0; i < 5; i++ {
-			out.Write([]byte("Enter passphrase: "))
-			passphrase, err = pr.ReadPassword()
-
-			if err == ErrNoTerminal {
-				return fmt.Errorf("out-key must be encrypted interactively")
-			} else if err != nil {
-				return fmt.Errorf("error reading passphrase: %s", err)
-			}
+		passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
+		if len(passphrase) == 0 {
+			for i := 0; i < 5; i++ {
+				out.Write([]byte("Enter passphrase: "))
+				passphrase, err = pr.ReadPassword()
+
+				if err == ErrNoTerminal {
+					return fmt.Errorf("out-key must be encrypted interactively")
+				} else if err != nil {
+					return fmt.Errorf("error reading passphrase: %s", err)
+				}
 
-			if len(passphrase) > 0 {
-				break
+				if len(passphrase) > 0 {
+					break
+				}
 			}
-		}
 
-		if len(passphrase) == 0 {
-			return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+			if len(passphrase) == 0 {
+				return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+			}
 		}
 	}
 

+ 11 - 0
cmd/nebula-cert/ca_test.go

@@ -171,6 +171,17 @@ func Test_ca(t *testing.T) {
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Empty(t, eb.String())
 
+	// test encrypted key with passphrase environment variable
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
+	require.NoError(t, ca(args, ob, eb, testpw))
+	assert.Empty(t, eb.String())
+	os.Setenv("NEBULA_CA_PASSPHRASE", "")
+
 	// read encrypted key file and verify default params
 	rb, _ = os.ReadFile(keyF.Name())
 	k, _ := pem.Decode(rb)

+ 18 - 0
cmd/nebula-cert/main.go

@@ -5,10 +5,28 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"runtime/debug"
+	"strings"
 )
 
+// A version string that can be set with
+//
+//	-ldflags "-X main.Build=SOMEVERSION"
+//
+// at compile-time.
 var Build string
 
+func init() {
+	if Build == "" {
+		info, ok := debug.ReadBuildInfo()
+		if !ok {
+			return
+		}
+
+		Build = strings.TrimPrefix(info.Main.Version, "v")
+	}
+}
+
 type helpError struct {
 	s string
 }

+ 35 - 29
cmd/nebula-cert/sign.go

@@ -43,7 +43,7 @@ type signFlags struct {
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
-	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -116,26 +116,28 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		// naively attempt to decode the private key as though it is not encrypted
 		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
 		if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
-			// ask for a passphrase until we get one
 			var passphrase []byte
-			for i := 0; i < 5; i++ {
-				out.Write([]byte("Enter passphrase: "))
-				passphrase, err = pr.ReadPassword()
-
-				if errors.Is(err, ErrNoTerminal) {
-					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
-				} else if err != nil {
-					return fmt.Errorf("error reading password: %s", err)
+			passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE"))
+			if len(passphrase) == 0 {
+				// ask for a passphrase until we get one
+				for i := 0; i < 5; i++ {
+					out.Write([]byte("Enter passphrase: "))
+					passphrase, err = pr.ReadPassword()
+
+					if errors.Is(err, ErrNoTerminal) {
+						return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+					} else if err != nil {
+						return fmt.Errorf("error reading password: %s", err)
+					}
+
+					if len(passphrase) > 0 {
+						break
+					}
 				}
-
-				if len(passphrase) > 0 {
-					break
+				if len(passphrase) == 0 {
+					return fmt.Errorf("cannot open encrypted ca-key without passphrase")
 				}
 			}
-			if len(passphrase) == 0 {
-				return fmt.Errorf("cannot open encrypted ca-key without passphrase")
-			}
-
 			curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
 			if err != nil {
 				return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
@@ -165,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("ca certificate is expired")
 	}
 
+	if version == 0 {
+		version = caCert.Version()
+	}
+
 	// if no duration is given, expire one second before the root expires
 	if *sf.duration <= 0 {
 		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -277,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	notBefore := time.Now()
 	notAfter := notBefore.Add(*sf.duration)
 
-	if version == 0 || version == cert.Version1 {
-		// Make sure we at least have an ip
+	switch version {
+	case cert.Version1:
+		// Make sure we have only one ipv4 address
 		if len(v4Networks) != 1 {
 			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
 		}
 
-		if version == cert.Version1 {
-			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
-			if len(v6Networks) > 0 {
-				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
-			}
+		if len(v6Networks) > 0 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
+		}
 
-			if len(v6UnsafeNetworks) > 0 {
-				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
-			}
+		if len(v6UnsafeNetworks) > 0 {
+			return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
 		}
 
 		t := &cert.TBSCertificate{
@@ -321,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		}
 
 		crts = append(crts, nc)
-	}
 
-	if version == 0 || version == cert.Version2 {
+	case cert.Version2:
 		t := &cert.TBSCertificate{
 			Version:        cert.Version2,
 			Name:           *sf.name,
@@ -351,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		}
 
 		crts = append(crts, nc)
+	default:
+		// this should be unreachable
+		return fmt.Errorf("invalid version: %d", version)
 	}
 
 	if !isP11 && *sf.inPubPath == "" {

+ 22 - 2
cmd/nebula-cert/sign_test.go

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
 			"  -unsafe-networks string\n"+
 			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
 			"  -version uint\n"+
-			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
+			"    \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
 		ob.String(),
 	)
 }
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -379,6 +379,15 @@ func Test_signCert(t *testing.T) {
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 
+	// test with the proper password in the environment
+	os.Remove(crtF.Name())
+	os.Remove(keyF.Name())
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase))
+	require.NoError(t, signCert(args, ob, eb, testpw))
+	assert.Empty(t, eb.String())
+	os.Setenv("NEBULA_CA_PASSPHRASE", "")
+
 	// test with the wrong password
 	ob.Reset()
 	eb.Reset()
@@ -389,6 +398,17 @@ func Test_signCert(t *testing.T) {
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 
+	// test with the wrong password in environment
+	ob.Reset()
+	eb.Reset()
+
+	os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	os.Setenv("NEBULA_CA_PASSPHRASE", "")
+
 	// test with the user not entering a password
 	ob.Reset()
 	eb.Reset()

+ 13 - 0
cmd/nebula-service/main.go

@@ -4,6 +4,8 @@ import (
 	"flag"
 	"fmt"
 	"os"
+	"runtime/debug"
+	"strings"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
@@ -18,6 +20,17 @@ import (
 // at compile-time.
 var Build string
 
+func init() {
+	if Build == "" {
+		info, ok := debug.ReadBuildInfo()
+		if !ok {
+			return
+		}
+
+		Build = strings.TrimPrefix(info.Main.Version, "v")
+	}
+}
+
 func main() {
 	serviceFlag := flag.String("service", "", "Control the system service.")
 	configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")

+ 13 - 0
cmd/nebula/main.go

@@ -4,6 +4,8 @@ import (
 	"flag"
 	"fmt"
 	"os"
+	"runtime/debug"
+	"strings"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
@@ -18,6 +20,17 @@ import (
 // at compile-time.
 var Build string
 
+func init() {
+	if Build == "" {
+		info, ok := debug.ReadBuildInfo()
+		if !ok {
+			return
+		}
+
+		Build = strings.TrimPrefix(info.Main.Version, "v")
+	}
+}
+
 func main() {
 	configPath := flag.String("config", "", "Path to either a file or directory to load configuration from")
 	configTest := flag.Bool("test", false, "Test the config and print the end result. Non zero exit indicates a faulty config")

+ 1 - 1
config/config.go

@@ -17,7 +17,7 @@ import (
 
 	"dario.cat/mergo"
 	"github.com/sirupsen/logrus"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 
 type C struct {

+ 1 - 1
config/config_test.go

@@ -10,7 +10,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 
 func TestConfig_Load(t *testing.T) {

+ 60 - 24
connection_manager.go

@@ -354,9 +354,8 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
 
 		if mainHostInfo {
 			decision = tryRehandshake
-
 		} else {
-			if cm.shouldSwapPrimary(hostinfo, primary) {
+			if cm.shouldSwapPrimary(hostinfo) {
 				decision = swapPrimary
 			} else {
 				// migrate the relays to the primary, if in use.
@@ -447,7 +446,7 @@ func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time
 	return inactiveDuration, true
 }
 
-func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
+func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
 	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
@@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool
 	}
 
 	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	if crt == nil {
+		//my cert was reloaded away. We should definitely swap from this tunnel
+		return true
+	}
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
 	cm.hostMap.Unlock()
 }
 
-// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
-// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
-// check and return true.
+// isInvalidCertificate decides if we should destroy a tunnel.
+// returns true if pki.disconnect_invalid is true and the certificate is no longer valid.
+// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
 func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
-		return false
+		return false //don't tear down tunnels for handshakes in progress
 	}
 
 	caPool := cm.intf.pki.GetCAPool()
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	if err == nil {
-		return false
-	}
-
-	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+		return false //cert is still valid! yay!
+	} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
 		// Block listed certificates should always be disconnected
+		hostinfo.logger(cm.l).WithError(err).
+			WithField("fingerprint", remoteCert.Fingerprint).
+			Info("Remote certificate is blocked, tearing down the tunnel")
+		return true
+	} else if cm.intf.disconnectInvalid.Load() {
+		hostinfo.logger(cm.l).WithError(err).
+			WithField("fingerprint", remoteCert.Fingerprint).
+			Info("Remote certificate is no longer valid, tearing down the tunnel")
+		return true
+	} else {
+		//if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open
 		return false
 	}
-
-	hostinfo.logger(cm.l).WithError(err).
-		WithField("fingerprint", remoteCert.Fingerprint).
-		Info("Remote certificate is no longer valid, tearing down the tunnel")
-
-	return true
 }
 
 func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
@@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
 func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
-	myCrt := cs.getCertificate(curCrt.Version())
-	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
-		// The current tunnel is using the latest certificate and version, no need to rehandshake.
+	curCrtVersion := curCrt.Version()
+	myCrt := cs.getCertificate(curCrtVersion)
+	if myCrt == nil {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("version", curCrtVersion).
+			WithField("reason", "local certificate removed").
+			Info("Re-handshaking with remote")
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 		return
 	}
+	peerCrt := hostinfo.ConnectionState.peerCert
+	if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
+		// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
+		if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
+			cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+				WithField("version", curCrtVersion).
+				WithField("peerVersion", peerCrt.Certificate.Version()).
+				WithField("reason", "local certificate version lower than peer, attempting to correct").
+				Info("Re-handshaking with remote")
+			cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
+				hh.initiatingVersionOverride = peerCrt.Certificate.Version()
+			})
+			return
+		}
+	}
+	if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("reason", "local certificate is not current").
+			Info("Re-handshaking with remote")
 
-	cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
-		WithField("reason", "local certificate is not current").
-		Info("Re-handshaking with remote")
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		return
+	}
+	if curCrtVersion < cs.initiatingVersion {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("reason", "current cert version < pki.initiatingVersion").
+			Info("Re-handshaking with remote")
 
-	cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		return
+	}
 }

+ 5 - 1
connection_manager_test.go

@@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse {
 		addrMap:   map[netip.Addr]*RemoteList{},
 		queryChan: make(chan netip.Addr, 10),
 	}
-	lighthouses := map[netip.Addr]struct{}{}
+	lighthouses := []netip.Addr{}
 	staticList := map[netip.Addr]struct{}{}
 
 	lh.lighthouses.Store(&lighthouses)
@@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte {
 	return d.publicKey
 }
 
+func (d *dummyCert) MarshalPublicKeyPEM() []byte {
+	return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey)
+}
+
 func (d *dummyCert) Signature() []byte {
 	return d.signature
 }

+ 1 - 6
connection_state.go

@@ -50,11 +50,6 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
 	}
 
 	static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
-
-	b := NewBits(ReplayWindow)
-	// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
-	b.Update(l, 0)
-
 	hs, err := noise.NewHandshakeState(noise.Config{
 		CipherSuite:   ncs,
 		Random:        rand.Reader,
@@ -74,7 +69,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
 	ci := &ConnectionState{
 		H:         hs,
 		initiator: initiator,
-		window:    b,
+		window:    NewBits(ReplayWindow),
 		myCert:    crt,
 	}
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.

+ 4 - 0
control_tester.go

@@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }
 
+func (c *Control) GetF() *Interface {
+	return c.f
+}
+
 func (c *Control) GetCertState() *CertState {
 	return c.f.pki.getCertState()
 }

+ 209 - 2
e2e/handshakes_test.go

@@ -20,16 +20,17 @@ import (
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 	// Start the servers
 	myControl.Start()
@@ -38,6 +39,41 @@ func BenchmarkHotPath(b *testing.B) {
 	r := router.NewR(b, myControl, theirControl)
 	r.CancelFlowLogs()
 
+	assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	b.ResetTimer()
+
+	for n := 0; n < b.N; n++ {
+		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+		_ = r.RouteForAllUntilTxTun(theirControl)
+	}
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func BenchmarkHotPathRelay(b *testing.B) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(b, myControl, relayControl, theirControl)
+	r.CancelFlowLogs()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	b.ResetTimer()
+
 	for n := 0; n < b.N; n++ {
 		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
@@ -45,6 +81,7 @@ func BenchmarkHotPath(b *testing.B) {
 
 	myControl.Stop()
 	theirControl.Stop()
+	relayControl.Stop()
 }
 
 func TestGoodHandshake(t *testing.T) {
@@ -97,6 +134,41 @@ func TestGoodHandshake(t *testing.T) {
 	theirControl.Stop()
 }
 
+func TestGoodHandshakeNoOverlap(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack!
+
+	// Put their info in our lighthouse
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	empty := []byte{}
+	t.Log("do something to cause a handshake")
+	myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty)
+
+	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
+	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
+
+	t.Log("Get their stage 1 packet")
+	stage1Packet := theirControl.GetFromUDP(true)
+
+	t.Log("Have me consume their stage 1 packet. I have a tunnel now")
+	myControl.InjectUDPPacket(stage1Packet)
+
+	t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete")
+	myControl.WaitForType(header.Test, 0, theirControl)
+
+	t.Log("Make sure our host infos are correct")
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
 func TestWrongResponderHandshake(t *testing.T) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 
@@ -464,6 +536,35 @@ func TestRelays(t *testing.T) {
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 }
 
+func TestRelaysDontCareAboutIps(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay  ", "2001::9999/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
+	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
+}
+
 func TestReestablishRelays(t *testing.T) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
@@ -1227,3 +1328,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh  ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}})
+
+	o := m{
+		"static_host_map": m{
+			lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()},
+		},
+		"lighthouse": m{
+			"hosts": []string{lhVpnIpNet[0].Addr().String()},
+			"local_allow_list": m{
+				// Try and block our lighthouse updates from using the actual addresses assigned to this computer
+				// If we start discovering addresses the test router doesn't know about then test traffic cant flow
+				"10.0.0.0/24": true,
+				"::/0":        false,
+			},
+		},
+	}
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me  ", "10.128.0.2/24, ff::2/64", o)
+	theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, lhControl, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	lhControl.Start()
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up an ipv6 tunnel between me and them")
+	assert.True(t, myVpnIpNet[1].Addr().Is6())
+	assert.True(t, theirVpnIpNet[1].Addr().Is6())
+	assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
+
+	lhControl.Stop()
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestGoodHandshakeUnsafeDest(t *testing.T) {
+	unsafePrefix := "192.168.6.0/24"
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil)
+	route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()}
+	myCfg := m{
+		"tun": m{
+			"unsafe_routes": []m{route},
+		},
+	}
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg)
+	t.Logf("my config %v", myConfig)
+	// Put their info in our lighthouse
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	spookyDest := netip.MustParseAddr("192.168.6.4")
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
+	myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+
+	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
+	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
+
+	t.Log("Get their stage 1 packet so that we can play with it")
+	stage1Packet := theirControl.GetFromUDP(true)
+
+	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
+	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
+	badPacket := stage1Packet.Copy()
+	badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len]
+	myControl.InjectUDPPacket(badPacket)
+
+	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
+	myControl.InjectUDPPacket(stage1Packet)
+
+	t.Log("Wait until we see my cached packet come through")
+	myControl.WaitForType(1, 0, theirControl)
+
+	t.Log("Make sure our host infos are correct")
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
+
+	t.Log("Get that cached packet and make sure it looks right")
+	myCachedPacket := theirControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80)
+
+	//reply
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))
+	//wait for reply
+	theirControl.WaitForType(1, 0, myControl)
+	theirCachedPacket := myControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80)
+
+	t.Log("Do a bidirectional tunnel test")
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 158 - 13
e2e/helpers_test.go

@@ -22,15 +22,14 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/stretchr/testify/assert"
-	"gopkg.in/yaml.v3"
+	"github.com/stretchr/testify/require"
+	"go.yaml.in/yaml/v3"
 )
 
 type m = map[string]any
 
 // newSimpleServer creates a nebula instance with many assumptions
 func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
-	l := NewTestLogger()
-
 	var vpnNetworks []netip.Prefix
 	for _, sn := range strings.Split(sVpnNetworks, ",") {
 		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
-	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
+	return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
+}
+
+func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides)
+}
+
+func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	l := NewTestLogger()
+
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
+	}
+
+	firewallInbound := []m{{
+		"proto": "any",
+		"port":  "any",
+		"host":  "any",
+	}}
+
+	var unsafeNetworks []netip.Prefix
+	if sUnsafeNetworks != "" {
+		firewallInbound = []m{{
+			"proto":      "any",
+			"port":       "any",
+			"host":       "any",
+			"local_cidr": "0.0.0.0/0",
+		}}
+
+		for _, sn := range strings.Split(sUnsafeNetworks, ",") {
+			x, err := netip.ParsePrefix(strings.TrimSpace(sn))
+			if err != nil {
+				panic(err)
+			}
+			unsafeNetworks = append(unsafeNetworks, x)
+		}
+	}
+
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{})
 
 	caB, err := caCrt.MarshalPEM()
 	if err != nil {
@@ -70,6 +116,104 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 			"key":  string(myPrivKey),
 		},
 		//"tun": m{"disabled": true},
+		"firewall": m{
+			"outbound": []m{{
+				"proto": "any",
+				"port":  "any",
+				"host":  "any",
+			}},
+			"inbound": firewallInbound,
+		},
+		//"handshakes": m{
+		//	"try_interval": "1s",
+		//},
+		"listen": m{
+			"host": udpAddr.Addr().String(),
+			"port": udpAddr.Port(),
+		},
+		"logging": m{
+			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
+			"level":            l.Level.String(),
+		},
+		"timers": m{
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
+		},
+	}
+
+	if overrides != nil {
+		final := m{}
+		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = final
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	c := config.NewC(l)
+	c.LoadString(string(cb))
+
+	control, err := nebula.Main(c, false, "e2e-test", l, nil)
+
+	if err != nil {
+		panic(err)
+	}
+
+	return control, vpnNetworks, udpAddr, c
+}
+
+// newServer creates a nebula instance with fewer assumptions
+func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	l := NewTestLogger()
+
+	vpnNetworks := certs[len(certs)-1].Networks()
+
+	var udpAddr netip.AddrPort
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
+		budpIp[1] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+	} else {
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
+	}
+
+	caStr := ""
+	for _, ca := range caCrt {
+		x, err := ca.MarshalPEM()
+		if err != nil {
+			panic(err)
+		}
+		caStr += string(x)
+	}
+	certStr := ""
+	for _, c := range certs {
+		x, err := c.MarshalPEM()
+		if err != nil {
+			panic(err)
+		}
+		certStr += string(x)
+	}
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": certStr,
+			"key":  string(key),
+		},
+		//"tun": m{"disabled": true},
 		"firewall": m{
 			"outbound": []m{{
 				"proto": "any",
@@ -90,7 +234,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 			"port": udpAddr.Port(),
 		},
 		"logging": m{
-			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
+			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()),
 			"level":            l.Level.String(),
 		},
 		"timers": m{
@@ -101,7 +245,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 
 	if overrides != nil {
 		final := m{}
-		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
+		err := mergo.Merge(&final, overrides, mergo.WithAppendSlice)
 		if err != nil {
 			panic(err)
 		}
@@ -118,7 +262,8 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 	}
 
 	c := config.NewC(l)
-	c.LoadString(string(cb))
+	cStr := string(cb)
+	c.LoadString(cStr)
 
 	control, err := nebula.Main(c, false, "e2e-test", l, nil)
 
@@ -147,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 	}
 }
 
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -163,10 +308,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
 	// Get both host infos
 	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
 	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
+	require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
 	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
+	require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 	// Check that both vpn and real addr are correct
 	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
@@ -180,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
 }
 
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	if toIp.Is6() {
 		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
 	} else {
@@ -188,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	}
 }
 
-func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
 	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
 	assert.NotNil(t, v6, "No ipv6 data found")
@@ -207,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 
-func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")

+ 310 - 0
e2e/tunnels_test.go

@@ -4,12 +4,16 @@
 package e2e
 
 import (
+	"fmt"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/stretchr/testify/assert"
+	"gopkg.in/yaml.v3"
 )
 
 func TestDropInactiveTunnels(t *testing.T) {
@@ -55,3 +59,309 @@ func TestDropInactiveTunnels(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestCertUpgrade(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	caB, err := ca.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	ca2B, err := ca2.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
+
+	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	_, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": string(myCert2Pem),
+			"key":  string(myPrivKey),
+		},
+		//"tun": m{"disabled": true},
+		"firewall": myC.Settings["firewall"],
+		//"handshakes": m{
+		//	"try_interval": "1s",
+		//},
+		"listen":  myC.Settings["listen"],
+		"logging": myC.Settings["logging"],
+		"timers":  myC.Settings["timers"],
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	r.Logf("reload new v2-only config")
+	err = myC.ReloadConfigString(string(cb))
+	assert.NoError(t, err)
+	r.Log("yay, spin until their sees it")
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		if c == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			r.Logf("version %d", version)
+			if version == cert.Version2 {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*10 {
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestCertDowngrade(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	caB, err := ca.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	ca2B, err := ca2.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
+
+	myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	//r.Log("yay")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": string(myCertPem),
+			"key":  string(myPrivKey),
+		},
+		"firewall": myC.Settings["firewall"],
+		"listen":   myC.Settings["listen"],
+		"logging":  myC.Settings["logging"],
+		"timers":   myC.Settings["timers"],
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	r.Logf("reload new v1-only config")
+	err = myC.ReloadConfigString(string(cb))
+	assert.NoError(t, err)
+	r.Log("yay, spin until their sees it")
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+		if c == nil || c2 == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			theirVersion := c2.Cert.Version()
+			r.Logf("version %d,%d", version, theirVersion)
+			if version == cert.Version1 {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*5 {
+			r.Log("it is unusual that the cert is not new yet, but not a failure yet")
+		}
+		if since > time.Second*10 {
+			r.Log("wtf")
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestCertMismatchCorrection(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	//r.Log("yay")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+		if c == nil || c2 == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			theirVersion := c2.Cert.Version()
+			r.Logf("version %d,%d", version, theirVersion)
+			if version == theirVersion {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*5 {
+			r.Log("wtf")
+		}
+		if since > time.Second*10 {
+			r.Log("wtf")
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestCrossStackRelaysWork(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me     ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay  ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
+	theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them   ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
+
+	//myVpnV4 := myVpnIpNet[0]
+	myVpnV6 := myVpnIpNet[1]
+	relayVpnV4 := relayVpnIpNet[0]
+	relayVpnV6 := relayVpnIpNet[1]
+	theirVpnV6 := theirVpnIpNet[0]
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
+
+	t.Log("reply?")
+	theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
+	p = r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
+
+	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
+	//t.Log("finish up")
+	//myControl.Stop()
+	//theirControl.Stop()
+	//relayControl.Stop()
+}

+ 3 - 2
examples/config.yml

@@ -424,8 +424,9 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
+  #     This can be used to filter destinations when using unsafe_routes.
   #     By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
   #     If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
   #   ca_name: An issuing CA name

+ 125 - 71
firewall.go

@@ -8,6 +8,7 @@ import (
 	"hash/fnv"
 	"net/netip"
 	"reflect"
+	"slices"
 	"strconv"
 	"strings"
 	"sync"
@@ -22,7 +23,7 @@ import (
 )
 
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
 }
 
 type conn struct {
@@ -247,22 +248,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
 }
 
 // AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
-	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
-	// https://github.com/golang/go/issues/14131
-	sIp := ""
-	if ip.IsValid() {
-		sIp = ip.String()
-	}
-	lIp := ""
-	if localIp.IsValid() {
-		lIp = localIp.String()
-	}
-
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
 	// We need this rule string because we generate a hash. Removing this will break firewall reload.
 	ruleString := fmt.Sprintf(
 		"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
-		incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
+		incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
 	)
 	f.rules += ruleString + "\n"
 
@@ -270,7 +260,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 	if !incoming {
 		direction = "outgoing"
 	}
-	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
+	f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
 		Info("Firewall rule added")
 
 	var (
@@ -297,7 +287,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		return fmt.Errorf("unknown protocol %v", proto)
 	}
 
-	return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
+	return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
 }
 
 // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -337,7 +327,6 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 	}
 
 	for i, t := range rs {
-		var groups []string
 		r, err := convertRule(l, t, table, i)
 		if err != nil {
 			return fmt.Errorf("%s rule #%v; %s", table, i, err)
@@ -347,23 +336,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
 		}
 
-		if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
+		if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
 			return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
 		}
 
-		if len(r.Groups) > 0 {
-			groups = r.Groups
-		}
-
-		if r.Group != "" {
-			// Check if we have both groups and group provided in the rule config
-			if len(groups) > 0 {
-				return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
-			}
-
-			groups = []string{r.Group}
-		}
-
 		var sPort, errPort string
 		if r.Code != "" {
 			errPort = "code"
@@ -392,23 +368,25 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 		}
 
-		var cidr netip.Prefix
-		if r.Cidr != "" {
-			cidr, err = netip.ParsePrefix(r.Cidr)
+		if r.Cidr != "" && r.Cidr != "any" {
+			_, err = netip.ParsePrefix(r.Cidr)
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
 			}
 		}
 
-		var localCidr netip.Prefix
-		if r.LocalCidr != "" {
-			localCidr, err = netip.ParsePrefix(r.LocalCidr)
+		if r.LocalCidr != "" && r.LocalCidr != "any" {
+			_, err = netip.ParsePrefix(r.LocalCidr)
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
 			}
 		}
 
-		err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
+		if warning := r.sanity(); warning != nil {
+			l.Warnf("%s rule #%v; %s", table, i, warning)
+		}
+
+		err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
 		if err != nil {
 			return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
 		}
@@ -417,8 +395,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 	return nil
 }
 
-var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
-var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
+var ErrUnknownNetworkType = errors.New("unknown network type")
+var ErrPeerRejected = errors.New("remote address is not within a network that we handle")
+var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks")
+var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
 var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
@@ -429,18 +409,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 		return nil
 	}
 
-	// Make sure remote address matches nebula certificate
-	if h.networks != nil {
-		if !h.networks.Contains(fp.RemoteAddr) {
+	// Make sure remote address matches nebula certificate, and determine how to treat it
+	if h.networks == nil {
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one address and no unsafe networks
-		if h.vpnAddrs[0] != fp.RemoteAddr {
+		nwType, ok := h.networks.Lookup(fp.RemoteAddr)
+		if !ok {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
+		switch nwType {
+		case NetworkTypeVPN:
+			break // nothing special
+		case NetworkTypeVPNPeer:
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
+			return ErrPeerRejected // reject for now, one day this may have different FW rules
+		case NetworkTypeUnsafe:
+			break // nothing special, one day this may have different FW rules
+		default:
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
+			return ErrUnknownNetworkType //should never happen
+		}
 	}
 
 	// Make sure we are supposed to be handling this local ip address
@@ -640,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
 	return false
 }
 
-func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 	}
@@ -653,7 +646,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
 			}
 		}
 
-		if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
+		if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
 			return err
 		}
 	}
@@ -684,7 +677,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 
-func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
 	fr := func() *FirewallRule {
 		return &FirewallRule{
 			Hosts:  make(map[string]*firewallLocalCIDR),
@@ -698,14 +691,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
 			fc.Any = fr()
 		}
 
-		return fc.Any.addRule(f, groups, host, ip, localIp)
+		return fc.Any.addRule(f, groups, host, cidr, localCidr)
 	}
 
 	if caSha != "" {
 		if _, ok := fc.CAShas[caSha]; !ok {
 			fc.CAShas[caSha] = fr()
 		}
-		err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
+		err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
 		if err != nil {
 			return err
 		}
@@ -715,7 +708,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
 		if _, ok := fc.CANames[caName]; !ok {
 			fc.CANames[caName] = fr()
 		}
-		err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
+		err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
 		if err != nil {
 			return err
 		}
@@ -747,24 +740,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
 	return fc.CANames[s.Certificate.Name()].match(p, c)
 }
 
-func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
 	flc := func() *firewallLocalCIDR {
 		return &firewallLocalCIDR{
 			LocalCIDR: new(bart.Lite),
 		}
 	}
 
-	if fr.isAny(groups, host, ip) {
+	if fr.isAny(groups, host, cidr) {
 		if fr.Any == nil {
 			fr.Any = flc()
 		}
 
-		return fr.Any.addRule(f, localCIDR)
+		return fr.Any.addRule(f, localCidr)
 	}
 
 	if len(groups) > 0 {
 		nlc := flc()
-		err := nlc.addRule(f, localCIDR)
+		err := nlc.addRule(f, localCidr)
 		if err != nil {
 			return err
 		}
@@ -780,30 +773,34 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, l
 		if nlc == nil {
 			nlc = flc()
 		}
-		err := nlc.addRule(f, localCIDR)
+		err := nlc.addRule(f, localCidr)
 		if err != nil {
 			return err
 		}
 		fr.Hosts[host] = nlc
 	}
 
-	if ip.IsValid() {
-		nlc, _ := fr.CIDR.Get(ip)
+	if cidr != "" {
+		c, err := netip.ParsePrefix(cidr)
+		if err != nil {
+			return err
+		}
+		nlc, _ := fr.CIDR.Get(c)
 		if nlc == nil {
 			nlc = flc()
 		}
-		err := nlc.addRule(f, localCIDR)
+		err = nlc.addRule(f, localCidr)
 		if err != nil {
 			return err
 		}
-		fr.CIDR.Insert(ip, nlc)
+		fr.CIDR.Insert(c, nlc)
 	}
 
 	return nil
 }
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
-	if len(groups) == 0 && host == "" && !ip.IsValid() {
+func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
+	if len(groups) == 0 && host == "" && cidr == "" {
 		return true
 	}
 
@@ -817,7 +814,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
 		return true
 	}
 
-	if ip.IsValid() && ip.Bits() == 0 {
+	if cidr == "any" {
 		return true
 	}
 
@@ -869,8 +866,13 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
 	return false
 }
 
-func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
-	if !localIp.IsValid() {
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
+	if localCidr == "any" {
+		flc.Any = true
+		return nil
+	}
+
+	if localCidr == "" {
 		if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
 			flc.Any = true
 			return nil
@@ -881,12 +883,13 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 		}
 		return nil
 
-	} else if localIp.Bits() == 0 {
-		flc.Any = true
-		return nil
 	}
 
-	flc.LocalCIDR.Insert(localIp)
+	c, err := netip.ParsePrefix(localCidr)
+	if err != nil {
+		return err
+	}
+	flc.LocalCIDR.Insert(c)
 	return nil
 }
 
@@ -907,7 +910,6 @@ type rule struct {
 	Code      string
 	Proto     string
 	Host      string
-	Group     string
 	Groups    []string
 	Cidr      string
 	LocalCidr string
@@ -949,7 +951,8 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
 		l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
 		m["group"] = v[0]
 	}
-	r.Group = toString("group", m)
+
+	singleGroup := toString("group", m)
 
 	if rg, ok := m["groups"]; ok {
 		switch reflect.TypeOf(rg).Kind() {
@@ -966,9 +969,60 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
 		}
 	}
 
+	//flatten group vs groups
+	if singleGroup != "" {
+		// Check if we have both groups and group provided in the rule config
+		if len(r.Groups) > 0 {
+			return r, fmt.Errorf("only one of group or groups should be defined, both provided")
+		}
+		r.Groups = []string{singleGroup}
+	}
+
 	return r, nil
 }
 
+// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
+// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
+func (r *rule) sanity() error {
+	//port, proto, local_cidr are AND, no need to check here
+	//ca_sha and ca_name don't have a wildcard value, no need to check here
+	groupsEmpty := len(r.Groups) == 0
+	hostEmpty := r.Host == ""
+	cidrEmpty := r.Cidr == ""
+
+	if (groupsEmpty && hostEmpty && cidrEmpty) == true {
+		return nil //no content!
+	}
+
+	groupsHasAny := slices.Contains(r.Groups, "any")
+	if groupsHasAny && len(r.Groups) > 1 {
+		return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
+	}
+
+	if r.Host == "any" {
+		if !groupsEmpty {
+			return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
+		}
+
+		if !cidrEmpty {
+			return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
+		}
+	}
+
+	if groupsHasAny {
+		if !hostEmpty && r.Host != "any" {
+			return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
+		}
+		if !cidrEmpty {
+			return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
+		}
+	}
+
+	//todo alert on cidr-any
+
+	return nil
+}
+
 func parsePort(s string) (startPort, endPort int32, err error) {
 	if s == "any" {
 		startPort = firewall.PortAny

+ 586 - 54
firewall_test.go

@@ -8,6 +8,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/gaissmai/bart"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
@@ -68,66 +70,117 @@ func TestFirewall_AddRule(t *testing.T) {
 	ti, err := netip.ParsePrefix("1.2.3.4/32")
 	require.NoError(t, err)
 
-	require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	ti6, err := netip.ParsePrefix("fd12::34/128")
+	require.NoError(t, err)
+
+	require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
 	assert.Nil(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
 	assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
 	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
 	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
+	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
+	_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
+	assert.True(t, ok)
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
+	ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
+	assert.True(t, ok)
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
-	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
+	ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	anyIp, err := netip.ParsePrefix("0.0.0.0/0")
 	require.NoError(t, err)
 
-	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
+	table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
+	assert.True(t, table.Any)
+	table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
+	assert.False(t, ok)
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	anyIp6, err := netip.ParsePrefix("::/0")
+	require.NoError(t, err)
+
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
+	table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
+	assert.True(t, table.Any)
+	table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
+	assert.False(t, ok)
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
+	assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
+	assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
+	assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
+	assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
+
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
-	require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
+	require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
 }
 
 func TestFirewall_Drop(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
-
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
 		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
@@ -152,10 +205,10 @@ func TestFirewall_Drop(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
-	h.buildNetworks(c.networks, c.unsafeNetworks)
+	h.buildNetworks(myVpnNetworksTable, &c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -174,28 +227,107 @@ func TestFirewall_Drop(t *testing.T) {
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
 	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+}
+
+func TestFirewall_DropV6(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
+
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("fd12::34"),
+		RemoteAddr: netip.MustParseAddr("fd12::34"),
+		LocalPort:  10,
+		RemotePort: 90,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+
+	c := dummyCert{
+		name:     "host1",
+		networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
+		},
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
+	}
+	h.buildNetworks(myVpnNetworksTable, &c)
+
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
+	cp := cert.NewCAPool()
+
+	// Drop outbound
+	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
+	// Allow inbound
+	resetConntrack(fw)
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	// Allow outbound because conntrack
+	require.NoError(t, fw.Drop(p, false, &h, cp, nil))
+
+	// test remote mismatch
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("fd12::56")
+	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
+	p.RemoteAddr = oldRemote
+
+	// ensure signer doesn't get in the way of group checks
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
+	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
+
+	// test caSha doesn't drop on match
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+
+	// ensure ca name doesn't get in the way of group checks
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
+	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
+
+	// test caName doesn't drop on match
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
 	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 
@@ -206,8 +338,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	}
 
 	pfix := netip.MustParsePrefix("172.1.1.1/32")
-	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
-	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
+
+	pfix6 := netip.MustParsePrefix("fd11::11/128")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -239,6 +375,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
+		ip := netip.MustParsePrefix("fd99::99/128")
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -252,6 +397,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
+			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
+		}
+	})
 
 	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -265,6 +422,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
+			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass on group on any local cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -289,6 +458,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass on group on specific local cidr6", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
+			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass on name", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -307,6 +487,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -332,7 +514,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate)
 
 	c1 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -347,10 +529,10 @@ func TestFirewall_Drop2(t *testing.T) {
 			peerCert: &c1,
 		},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
 	cp := cert.NewCAPool()
 
 	// h1/c1 lacks the proper groups
@@ -364,6 +546,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -395,7 +579,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
 
 	c2 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -410,7 +594,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
+	h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
 
 	c3 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -425,11 +609,11 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
+	h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
 	cp := cert.NewCAPool()
 
 	// c1 should pass because host match
@@ -443,14 +627,54 @@ func TestFirewall_Drop3(t *testing.T) {
 
 	// Test a remote address match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
 	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 }
 
+func TestFirewall_Drop3V6(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
+
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("fd12::34"),
+		RemoteAddr: netip.MustParseAddr("fd12::34"),
+		LocalPort:  1,
+		RemotePort: 1,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+
+	network := netip.MustParsePrefix("fd12::34/120")
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{network},
+		},
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+		vpnAddrs: []netip.Addr{network.Addr()},
+	}
+	h.buildNetworks(myVpnNetworksTable, c.Certificate)
+
+	// Test a remote address match
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+	cp := cert.NewCAPool()
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+}
+
 func TestFirewall_DropConntrackReload(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
 
 	p := firewall.Packet{
 		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
@@ -477,10 +701,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+	h.buildNetworks(myVpnNetworksTable, c.Certificate)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -493,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -502,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
-	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -510,6 +734,52 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
 }
 
+func TestFirewall_DropIPSpoofing(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+	myVpnNetworksTable := new(bart.Lite)
+	myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
+
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
+		},
+	}
+
+	c1 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:           "host",
+			networks:       []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
+			unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
+		},
+	}
+	h1 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c1,
+		},
+		vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
+	}
+	h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
+
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
+	cp := cert.NewCAPool()
+
+	// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("192.0.2.1"),
+		RemoteAddr: netip.MustParseAddr("192.0.2.3"),
+		LocalPort:  1,
+		RemotePort: 1,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+	assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
+}
+
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {
@@ -689,28 +959,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
 
 	// Test adding udp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
 
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
 
 	// Test adding any rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
 
 	// Test adding rule with cidr
 	cidr := netip.MustParsePrefix("10.0.0.0/8")
@@ -718,49 +988,90 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
 
 	// Test adding rule with local_cidr
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
+
+	// Test adding rule with cidr ipv6
+	cidr6 := netip.MustParsePrefix("fd00::/8")
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
+
+	// Test adding rule with any cidr
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
+
+	// Test adding rule with junk cidr
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
+	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
+
+	// Test adding rule with local_cidr ipv6
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
+
+	// Test adding rule with any local_cidr
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
+
+	// Test adding rule with junk local_cidr
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
+	require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
 
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
 
 	// Test single group
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
 
 	// Test single groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
 
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
 
 	// Test Add error
 	conf = config.NewC(l)
@@ -783,7 +1094,7 @@ func TestFirewall_convertRule(t *testing.T) {
 	r, err := convertRule(l, c, "test", 1)
 	assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
 	require.NoError(t, err)
-	assert.Equal(t, "group1", r.Group)
+	assert.Equal(t, []string{"group1"}, r.Groups)
 
 	// Ensure group array of > 1 is errord
 	ob.Reset()
@@ -803,7 +1114,228 @@ func TestFirewall_convertRule(t *testing.T) {
 
 	r, err = convertRule(l, c, "test", 1)
 	require.NoError(t, err)
-	assert.Equal(t, "group1", r.Group)
+	assert.Equal(t, []string{"group1"}, r.Groups)
+}
+
+func TestFirewall_convertRuleSanity(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	noWarningPlease := []map[string]any{
+		{"group": "group1"},
+		{"groups": []any{"group2"}},
+		{"host": "bob"},
+		{"cidr": "1.1.1.1/1"},
+		{"groups": []any{"group2"}, "host": "bob"},
+		{"cidr": "1.1.1.1/1", "host": "bob"},
+		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
+		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
+	}
+	for _, c := range noWarningPlease {
+		r, err := convertRule(l, c, "test", 1)
+		require.NoError(t, err)
+		require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
+	}
+
+	yesWarningPlease := []map[string]any{
+		{"group": "group1"},
+		{"groups": []any{"group2"}},
+		{"cidr": "1.1.1.1/1"},
+		{"groups": []any{"group2"}, "host": "bob"},
+		{"cidr": "1.1.1.1/1", "host": "bob"},
+		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
+		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
+	}
+	for _, c := range yesWarningPlease {
+		c["host"] = "any"
+		r, err := convertRule(l, c, "test", 1)
+		require.NoError(t, err)
+		err = r.sanity()
+		require.Error(t, err, "I wanted a warning: %+v", c)
+	}
+	//reset the list
+	yesWarningPlease = []map[string]any{
+		{"group": "group1"},
+		{"groups": []any{"group2"}},
+		{"cidr": "1.1.1.1/1"},
+		{"groups": []any{"group2"}, "host": "bob"},
+		{"cidr": "1.1.1.1/1", "host": "bob"},
+		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
+		{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
+	}
+	for _, c := range yesWarningPlease {
+		r, err := convertRule(l, c, "test", 1)
+		require.NoError(t, err)
+		r.Groups = append(r.Groups, "any")
+		err = r.sanity()
+		require.Error(t, err, "I wanted a warning: %+v", c)
+	}
+}
+
+type testcase struct {
+	h   *HostInfo
+	p   firewall.Packet
+	c   cert.Certificate
+	err error
+}
+
+func (c *testcase) Test(t *testing.T, fw *Firewall) {
+	t.Helper()
+	cp := cert.NewCAPool()
+	resetConntrack(fw)
+	err := fw.Drop(c.p, true, c.h, cp, nil)
+	if c.err == nil {
+		require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
+	} else {
+		require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
+	}
+}
+
+func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
+	c1 := dummyCert{
+		name:     "host1",
+		networks: theirPrefixes,
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c1,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
+		},
+		vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
+	}
+	for i := range theirPrefixes {
+		h.vpnAddrs[i] = theirPrefixes[i].Addr()
+	}
+	h.buildNetworks(setup.myVpnNetworksTable, &c1)
+	p := firewall.Packet{
+		LocalAddr:  setup.c.Networks()[0].Addr(), //todo?
+		RemoteAddr: theirPrefixes[0].Addr(),
+		LocalPort:  10,
+		RemotePort: 90,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+	return testcase{
+		h:   &h,
+		p:   p,
+		c:   &c1,
+		err: err,
+	}
+}
+
+type testsetup struct {
+	c                  dummyCert
+	myVpnNetworksTable *bart.Lite
+	fw                 *Firewall
+}
+
+func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
+	c := dummyCert{
+		name:     "me",
+		networks: myPrefixes,
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+
+	return newSetupFromCert(t, l, c)
+}
+
+func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
+	myVpnNetworksTable := new(bart.Lite)
+	for _, prefix := range c.Networks() {
+		myVpnNetworksTable.Insert(prefix)
+	}
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
+
+	return testsetup{
+		c:                  c,
+		fw:                 fw,
+		myVpnNetworksTable: myVpnNetworksTable,
+	}
+}
+
+func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
+	t.Parallel()
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	myPrefix := netip.MustParsePrefix("1.1.1.1/8")
+	// for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
+	t.Run("allow inbound all matching", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
+		tc.Test(t, setup.fw)
+	})
+	t.Run("allow inbound local matching", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
+		tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
+		tc.Test(t, setup.fw)
+	})
+	t.Run("block inbound remote mismatched", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
+		tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
+		tc.Test(t, setup.fw)
+	})
+	t.Run("Block a vpn peer packet", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
+		tc.Test(t, setup.fw)
+	})
+	twoPrefixes := []netip.Prefix{
+		netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
+	}
+	t.Run("allow inbound one matching", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, nil, twoPrefixes...)
+		tc.Test(t, setup.fw)
+	})
+	t.Run("block inbound multimismatch", func(t *testing.T) {
+		t.Parallel()
+		setup := newSetup(t, l, myPrefix)
+		tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
+		tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
+		tc.Test(t, setup.fw)
+	})
+	t.Run("allow inbound 2nd one matching", func(t *testing.T) {
+		t.Parallel()
+		setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
+		tc := buildTestCase(setup2, nil, twoPrefixes...)
+		tc.p.RemoteAddr = twoPrefixes[1].Addr()
+		tc.Test(t, setup2.fw)
+	})
+	t.Run("allow inbound unsafe route", func(t *testing.T) {
+		t.Parallel()
+		unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
+		c := dummyCert{
+			name:           "me",
+			networks:       []netip.Prefix{myPrefix},
+			unsafeNetworks: []netip.Prefix{unsafePrefix},
+			groups:         []string{"default-group"},
+			issuer:         "signer-shasum",
+		}
+		unsafeSetup := newSetupFromCert(t, l, c)
+		tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
+		tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
+		tc.err = ErrNoMatchingRule
+		tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
+		require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
+		tc.err = nil
+		tc.Test(t, unsafeSetup.fw) //should pass
+	})
 }
 
 type addRuleCall struct {
@@ -813,8 +1345,8 @@ type addRuleCall struct {
 	endPort   int32
 	groups    []string
 	host      string
-	ip        netip.Prefix
-	localIp   netip.Prefix
+	ip        string
+	localIp   string
 	caName    string
 	caSha     string
 }
@@ -824,7 +1356,7 @@ type mockFirewall struct {
 	nextCallReturn error
 }
 
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
 	mf.lastCall = addRuleCall{
 		incoming:  incoming,
 		proto:     proto,

+ 19 - 19
go.mod

@@ -1,8 +1,6 @@
 module github.com/slackhq/nebula
 
-go 1.23.0
-
-toolchain go1.24.1
+go 1.25
 
 require (
 	dario.cat/mergo v1.0.2
@@ -10,30 +8,31 @@ require (
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
-	github.com/gaissmai/bart v0.20.4
+	github.com/gaissmai/bart v0.26.0
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
-	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.65
+	github.com/kardianos/service v1.2.4
+	github.com/miekg/dns v1.1.68
 	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.22.0
+	github.com/prometheus/client_golang v1.23.2
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
-	github.com/stretchr/testify v1.10.0
+	github.com/stretchr/testify v1.11.1
 	github.com/vishvananda/netlink v1.3.1
-	golang.org/x/crypto v0.37.0
+	go.yaml.in/yaml/v3 v3.0.4
+	golang.org/x/crypto v0.45.0
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-	golang.org/x/net v0.39.0
-	golang.org/x/sync v0.13.0
-	golang.org/x/sys v0.32.0
-	golang.org/x/term v0.31.0
+	golang.org/x/net v0.47.0
+	golang.org/x/sync v0.18.0
+	golang.org/x/sys v0.38.0
+	golang.org/x/term v0.37.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.36.6
+	google.golang.org/protobuf v1.36.10
 	gopkg.in/yaml.v3 v3.0.1
 	gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 )
@@ -45,11 +44,12 @@ require (
 	github.com/google/btree v1.1.2 // indirect
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.6.1 // indirect
-	github.com/prometheus/common v0.62.0 // indirect
-	github.com/prometheus/procfs v0.15.1 // indirect
+	github.com/prometheus/client_model v0.6.2 // indirect
+	github.com/prometheus/common v0.66.1 // indirect
+	github.com/prometheus/procfs v0.16.1 // indirect
 	github.com/vishvananda/netns v0.0.5 // indirect
-	golang.org/x/mod v0.23.0 // indirect
+	go.yaml.in/yaml/v2 v2.4.2 // indirect
+	golang.org/x/mod v0.24.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
-	golang.org/x/tools v0.30.0 // indirect
+	golang.org/x/tools v0.33.0 // indirect
 )

+ 38 - 33
go.sum

@@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
-github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U=
-github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk=
+github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
+github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -64,8 +64,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/
 github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
 github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
 github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
-github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
-github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk=
+github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
@@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc=
-github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
+github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
+github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q=
-github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0=
+github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
+github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
-github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
+github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
+github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
-github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
+github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
+github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
-github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
+github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
+github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -143,29 +143,35 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
 github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
 github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
 github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
 github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
+go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
+go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
+go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
+go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
+go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
 golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
-golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
+golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
+golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
-golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
+golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
+golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -176,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
-golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
+golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
+golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -185,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
-golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
+golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -197,18 +203,17 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
-golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
+golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
-golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
+golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
+golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -219,8 +224,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
-golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
+golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
+golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -239,8 +244,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
-google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
+google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
+google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

+ 146 - 154
handshake_ix.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"net/netip"
-	"slices"
 	"time"
 
 	"github.com/flynn/noise"
@@ -24,13 +23,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 		return false
 	}
 
-	// If we're connecting to a v6 address we must use a v2 cert
 	cs := f.pki.getCertState()
 	v := cs.initiatingVersion
-	for _, a := range hh.hostinfo.vpnAddrs {
-		if a.Is6() {
-			v = cert.Version2
-			break
+	if hh.initiatingVersionOverride != cert.VersionPre1 {
+		v = hh.initiatingVersionOverride
+	} else if v < cert.Version2 {
+		// If we're connecting to a v6 address we should encourage use of a V2 cert
+		for _, a := range hh.hostinfo.vpnAddrs {
+			if a.Is6() {
+				v = cert.Version2
+				break
+			}
 		}
 	}
 
@@ -49,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
 			WithField("certVersion", v).
 			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return false
 	}
 
 	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
@@ -105,19 +109,20 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	return true
 }
 
-func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
 	cs := f.pki.getCertState()
 	crt := cs.GetDefaultCertificate()
 	if crt == nil {
-		f.l.WithField("udpAddr", addr).
+		f.l.WithField("from", via).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
 			WithField("certVersion", cs.initiatingVersion).
 			Error("Unable to handshake with host because no certificate is available")
+		return
 	}
 
 	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Error("Failed to create connection state")
 		return
@@ -128,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Error("Failed to call noise.ReadMessage")
 		return
@@ -137,7 +142,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Error("Failed unmarshal handshake message")
 		return
@@ -145,7 +150,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("Handshake did not contain a certificate")
 		return
@@ -153,12 +158,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
 	if err != nil {
-		fp, err := rc.Fingerprint()
-		if err != nil {
+		fp, fperr := rc.Fingerprint()
+		if fperr != nil {
 			fp = "<error generating certificate fingerprint>"
 		}
 
-		e := f.l.WithError(err).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			WithField("certVpnNetworks", rc.Networks()).
 			WithField("certFingerprint", fp)
@@ -173,37 +178,40 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	if remoteCert.Certificate.Version() != ci.myCert.Version() {
 		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
-		rc := cs.getCertificate(remoteCert.Certificate.Version())
-		if rc == nil {
-			f.l.WithError(err).WithField("udpAddr", addr).
-				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
-				Info("Unable to handshake with host due to missing certificate version")
-			return
+		myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
+		if myCertOtherVersion == nil {
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithError(err).WithFields(m{
+					"from":      via,
+					"handshake": m{"stage": 1, "style": "ix_psk0"},
+					"cert":      remoteCert,
+				}).Debug("Might be unable to handshake with host due to missing certificate version")
+			}
+		} else {
+			// Record the certificate we are actually using
+			ci.myCert = myCertOtherVersion
 		}
-
-		// Record the certificate we are actually using
-		ci.myCert = rc
 	}
 
 	if len(remoteCert.Certificate.Networks()) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("cert", remoteCert).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("No networks in certificate")
 		return
 	}
 
-	var vpnAddrs []netip.Addr
-	var filteredNetworks []netip.Prefix
 	certName := remoteCert.Certificate.Name()
 	certVersion := remoteCert.Certificate.Version()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
+	vpnNetworks := remoteCert.Certificate.Networks()
 
-	for _, network := range remoteCert.Certificate.Networks() {
-		vpnAddr := network.Addr()
-		if f.myVpnAddrsTable.Contains(vpnAddr) {
-			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+	anyVpnAddrsInCommon := false
+	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+	for i, network := range vpnNetworks {
+		if f.myVpnAddrsTable.Contains(network.Addr()) {
+			f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -211,38 +219,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
 			return
 		}
-
-		// vpnAddrs outside our vpn networks are of no use to us, filter them out
-		if !f.myVpnNetworksTable.Contains(vpnAddr) {
-			continue
+		vpnAddrs[i] = network.Addr()
+		if f.myVpnNetworksTable.Contains(network.Addr()) {
+			anyVpnAddrsInCommon = true
 		}
-
-		filteredNetworks = append(filteredNetworks, network)
-		vpnAddrs = append(vpnAddrs, vpnAddr)
 	}
 
-	if len(vpnAddrs) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("certVersion", certVersion).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
-		return
-	}
-
-	if addr.IsValid() {
-		// addr can be invalid when the tunnel is being relayed.
+	if !via.IsRelayed {
 		// We only want to apply the remote allow list for direct tunnels here
-		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
+				Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
 
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -265,10 +259,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			TotalPorts:  uint32(f.multiPort.TxPorts),
 		}
 	}
-	if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(addr.Port()) {
+	if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(via.UdpAddr.Port()) {
 		// The other side sent us a handshake from a different port, make sure
 		// we send responses back to the BasePort
-		addr = netip.AddrPortFrom(addr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort))
+		via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort))
 	}
 
 	hostinfo := &HostInfo{
@@ -287,27 +281,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		},
 	}
 
-	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-		WithField("certName", certName).
-		WithField("certVersion", certVersion).
-		WithField("fingerprint", fingerprint).
-		WithField("issuer", issuer).
-		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-		WithField("multiportTx", multiportTx).WithField("multiportRx", multiportRx).
-		Info("Handshake message received")
+	msgRxL := f.l.WithFields(m{
+		"vpnAddrs":       vpnAddrs,
+		"from":           via,
+		"certName":       certName,
+		"certVersion":    certVersion,
+		"fingerprint":    fingerprint,
+		"issuer":         issuer,
+		"initiatorIndex": hs.Details.InitiatorIndex,
+		"responderIndex": hs.Details.ResponderIndex,
+		"remoteIndex":    h.RemoteIndex,
+		"multiportTx":    multiportTx,
+		"multiportRx":    multiportRx,
+		"handshake":      m{"stage": 1, "style": "ix_psk0"},
+	})
+
+	if anyVpnAddrsInCommon {
+		msgRxL.Info("Handshake message received")
+	} else {
+		//todo warn if not lighthouse or relay?
+		msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
+	}
 
 	hs.Details.ResponderIndex = myIndex
 	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
 	if hs.Details.Cert == nil {
-		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("certVersion", certVersion).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			WithField("certVersion", ci.myCert.Version()).
+		msgRxL.WithField("myCertVersion", ci.myCert.Version()).
 			Error("Unable to handshake with host because no certificate handshake bytes is available")
 		return
 	}
@@ -318,7 +317,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -330,7 +329,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -338,7 +337,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -364,8 +363,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	ci.eKey = NewNebulaCipherState(eKey)
 
 	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
-	hostinfo.SetRemote(addr)
-	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
+	if !via.IsRelayed {
+		hostinfo.SetRemote(via.UdpAddr)
+	}
+	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -373,10 +374,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		case ErrAlreadySeen:
 			if hostinfo.multiportRx {
 				// The other host is sending to us with multiport, so only grab the IP
-				addr = netip.AddrPortFrom(addr.Addr(), hostinfo.remote.Port())
+				via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
 			}
 			// Update remote if preferred
-			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
+			if existing.SetRemoteIfPreferred(f.hostMap, via) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -384,28 +385,29 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-			if addr.IsValid() {
+			if !via.IsRelayed {
+				err := f.outside.WriteTo(msg, via.UdpAddr)
 				if multiportTx {
 					// TODO remove alloc here
 					raw := make([]byte, len(msg)+udp.RawOverhead)
 					copy(raw[udp.RawOverhead:], msg)
-					err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), addr)
+					err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr)
 				} else {
-					err = f.outside.WriteTo(msg, addr)
+					err = f.outside.WriteTo(msg, via.UdpAddr)
 				}
 				if err != nil {
-					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 				} else {
-					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 				}
 				return
 			} else {
-				if via == nil {
-					f.l.Error("Handshake send failed: both addr and via are nil.")
+				if via.relay == nil {
+					f.l.Error("Handshake send failed: both addr and via.relay are nil.")
 					return
 				}
 				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -417,7 +419,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			}
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
@@ -433,7 +435,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -446,7 +448,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -460,37 +462,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	// Do the send
 	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-	if addr.IsValid() {
+	if !via.IsRelayed {
 		if multiportTx {
 			// TODO remove alloc here
 			raw := make([]byte, len(msg)+udp.RawOverhead)
 			copy(raw[udp.RawOverhead:], msg)
-			err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), addr)
+			err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr)
 		} else {
-			err = f.outside.WriteTo(msg, addr)
+			err = f.outside.WriteTo(msg, via.UdpAddr)
 		}
+		log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
+			WithField("certName", certName).
+			WithField("certVersion", certVersion).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 		if err != nil {
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-				WithField("certName", certName).
-				WithField("certVersion", certVersion).
-				WithField("fingerprint", fingerprint).
-				WithField("issuer", issuer).
-				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				WithError(err).Error("Failed to send handshake")
+			log.WithError(err).Error("Failed to send handshake")
 		} else {
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-				WithField("certName", certName).
-				WithField("certVersion", certVersion).
-				WithField("fingerprint", fingerprint).
-				WithField("issuer", issuer).
-				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				Info("Handshake message sent")
+			log.Info("Handshake message sent")
 		}
 	} else {
-		if via == nil {
-			f.l.Error("Handshake send failed: both addr and via are nil.")
+		if via.relay == nil {
+			f.l.Error("Handshake send failed: both addr and via.relay are nil.")
 			return
 		}
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -510,12 +505,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	f.connectionManager.AddTrafficWatch(hostinfo)
 
-	hostinfo.remotes.ResetBlockedRemotes()
+	hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
 
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
 	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -525,10 +520,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	defer hh.Unlock()
 
 	hostinfo := hh.hostinfo
-	if addr.IsValid() {
+	if !via.IsRelayed {
 		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
-		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
-			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
 	}
@@ -536,7 +531,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -545,7 +540,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 
@@ -557,7 +552,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -569,18 +564,18 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx
 	}
 
-	if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(addr.Port()) {
+	if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(via.UdpAddr.Port()) {
 		// The other side sent us a handshake from a different port, make sure
 		// we send responses back to the BasePort
-		addr = netip.AddrPortFrom(
-			addr.Addr(),
+		via.UdpAddr = netip.AddrPortFrom(
+			via.UdpAddr.Addr(),
 			uint16(hs.Details.ResponderMultiPort.BasePort),
 		)
 	}
 
 	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("vpnAddrs", hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Handshake did not contain a certificate")
@@ -594,7 +589,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			fp = "<error generating certificate fingerprint>"
 		}
 
-		e := f.l.WithError(err).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("from", via).
 			WithField("vpnAddrs", hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("certFingerprint", fp).
@@ -609,7 +604,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	}
 
 	if len(remoteCert.Certificate.Networks()) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("vpnAddrs", hostinfo.vpnAddrs).
 			WithField("cert", remoteCert).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -632,39 +627,30 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.eKey = NewNebulaCipherState(eKey)
 
 	// Make sure the current udpAddr being used is set for responding
-	if addr.IsValid() {
-		hostinfo.SetRemote(addr)
+	if !via.IsRelayed {
+		hostinfo.SetRemote(via.UdpAddr)
 	} else {
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 	}
 
-	var vpnAddrs []netip.Addr
-	var filteredNetworks []netip.Prefix
-	for _, network := range vpnNetworks {
-		// vpnAddrs outside our vpn networks are of no use to us, filter them out
-		vpnAddr := network.Addr()
-		if !f.myVpnNetworksTable.Contains(vpnAddr) {
-			continue
+	correctHostResponded := false
+	anyVpnAddrsInCommon := false
+	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+	for i, network := range vpnNetworks {
+		vpnAddrs[i] = network.Addr()
+		if f.myVpnNetworksTable.Contains(network.Addr()) {
+			anyVpnAddrsInCommon = true
+		}
+		if hostinfo.vpnAddrs[0] == network.Addr() {
+			// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
+			correctHostResponded = true
 		}
-
-		filteredNetworks = append(filteredNetworks, network)
-		vpnAddrs = append(vpnAddrs, vpnAddr)
-	}
-
-	if len(vpnAddrs) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("certVersion", certVersion).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
-		return true
 	}
 
 	// Ensure the right host responded
-	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+	if !correctHostResponded {
 		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
-			WithField("udpAddr", addr).
+			WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -674,10 +660,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
+		//TODO is hostinfo.vpnAddrs[0] always the address to use?
 		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
-			newHH.hostinfo.remotes.BlockRemote(addr)
+			newHH.hostinfo.remotes.BlockRemote(via)
 
 			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
 				WithField("vpnNetworks", vpnNetworks).
@@ -700,7 +687,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+	msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 		WithField("certName", certName).
 		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).
@@ -709,12 +696,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 		WithField("durationNs", duration).
 		WithField("sentCachedPackets", len(hh.packetStore)).
-		WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx).
-		Info("Handshake message received")
+		WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx)
+	if anyVpnAddrsInCommon {
+		msgRxL.Info("Handshake message received")
+	} else {
+		//todo warn if not lighthouse or relay?
+		msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
+	}
 
 	// Build up the radix for the firewall if we have subnets in the cert
 	hostinfo.vpnAddrs = vpnAddrs
-	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
+	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
@@ -733,7 +725,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
 	}
 
-	hostinfo.remotes.ResetBlockedRemotes()
+	hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
 	f.metricHandshakes.Update(duration)
 
 	return false

+ 14 - 13
handshake_manager.go

@@ -71,11 +71,12 @@ type HandshakeManager struct {
 type HandshakeHostInfo struct {
 	sync.Mutex
 
-	startTime   time.Time        // Time that we first started trying with this handshake
-	ready       bool             // Is the handshake ready
-	counter     int64            // How many attempts have we made so far
-	lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
-	packetStore []*cachedPacket  // A set of packets to be transmitted once the handshake completes
+	startTime                 time.Time        // Time that we first started trying with this handshake
+	ready                     bool             // Is the handshake ready
+	initiatingVersionOverride cert.Version     // Should we use a non-default cert version for this handshake?
+	counter                   int64            // How many attempts have we made so far
+	lastRemotes               []netip.AddrPort // Remotes that we sent to during the previous attempt
+	packetStore               []*cachedPacket  // A set of packets to be transmitted once the handshake completes
 
 	hostinfo *HostInfo
 }
@@ -138,11 +139,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
 	}
 }
 
-func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
-	if addr.IsValid() {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
-			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if !via.IsRelayed {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
+			hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
@@ -151,11 +152,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
 	case header.HandshakeIXPSK0:
 		switch h.MessageCounter {
 		case 1:
-			ixHandshakeStage1(hm.f, addr, via, packet, h)
+			ixHandshakeStage1(hm.f, via, packet, h)
 
 		case 2:
 			newHostinfo := hm.queryIndex(h.RemoteIndex)
-			tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
+			tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
 			if tearDown && newHostinfo != nil {
 				hm.DeleteHostInfo(newHostinfo.hostinfo)
 			}
@@ -294,12 +295,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself
+			// Don't relay through the host I'm trying to connect to
 			if relay == vpnIp {
 				continue
 			}
 
-			// Don't relay through the host I'm trying to connect to
+			// Don't relay to myself
 			if hm.f.myVpnAddrsTable.Contains(relay) {
 				continue
 			}

+ 56 - 28
hostmap.go

@@ -1,7 +1,9 @@
 package nebula
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"net"
 	"net/netip"
 	"slices"
@@ -17,12 +19,10 @@ import (
 	"github.com/slackhq/nebula/header"
 )
 
-// const ProbeLen = 100
 const defaultPromoteEvery = 1000       // Count of packets sent before we try moving a tunnel to a preferred underlay ip address
 const defaultReQueryEvery = 5000       // Count of packets sent before re-querying a hostinfo to the lighthouse
 const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
 const MaxRemotes = 10
-const maxRecvError = 4
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
 // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -214,6 +214,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.relayForByIdx[idx] = r
 }
 
+type NetworkType uint8
+
+const (
+	NetworkTypeUnknown NetworkType = iota
+	// NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate
+	NetworkTypeVPN
+	// NetworkTypeVPNPeer is a network that does not overlap one of our networks
+	NetworkTypeVPNPeer
+	// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
+	NetworkTypeUnsafe
+)
+
 type HostInfo struct {
 	remote          netip.AddrPort
 	remotes         *RemoteList
@@ -225,11 +237,10 @@ type HostInfo struct {
 	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
 	// The host may have other vpn addresses that are outside our
 	// vpn networks but were removed because they are not usable
-	vpnAddrs  []netip.Addr
-	recvError atomic.Uint32
+	vpnAddrs []netip.Addr
 
-	// networks are both all vpn and unsafe networks assigned to this host
-	networks   *bart.Lite
+	// networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host.
+	networks   *bart.Table[NetworkType]
 	relayState RelayState
 
 	// If true, we should send to this remote using multiport
@@ -273,9 +284,25 @@ type HostInfo struct {
 }
 
 type ViaSender struct {
+	UdpAddr   netip.AddrPort
 	relayHI   *HostInfo // relayHI is the host info object of the relay
 	remoteIdx uint32    // remoteIdx is the index included in the header of the received packet
 	relay     *Relay    // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
+	IsRelayed bool      // IsRelayed is true if the packet was sent through a relay
+}
+
+func (v ViaSender) String() string {
+	if v.IsRelayed {
+		return fmt.Sprintf("%s (relayed)", v.UdpAddr)
+	}
+	return v.UdpAddr.String()
+}
+
+func (v ViaSender) MarshalJSON() ([]byte, error) {
+	if v.IsRelayed {
+		return json.Marshal(m{"relay": v.UdpAddr})
+	}
+	return json.Marshal(m{"direct": v.UdpAddr})
 }
 
 type cachedPacket struct {
@@ -691,6 +718,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
 	return nil
 }
 
+// TODO: Maybe use ViaSender here?
 func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	if i.remote != remote {
@@ -701,14 +729,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
-	if !newRemote.IsValid() {
-		// relays have nil udp Addrs
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool {
+	if via.IsRelayed {
 		return false
 	}
+
 	currentRemote := i.remote
 	if !currentRemote.IsValid() {
-		i.SetRemote(newRemote)
+		i.SetRemote(via.UdpAddr)
 		return true
 	}
 
@@ -721,7 +749,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 			return false
 		}
 
-		if l.Contains(newRemote.Addr()) {
+		if l.Contains(via.UdpAddr.Addr()) {
 			newIsPreferred = true
 		}
 	}
@@ -731,7 +759,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 		i.lastRoam = time.Now()
 		i.lastRoamRemote = currentRemote
 
-		i.SetRemote(newRemote)
+		i.SetRemote(via.UdpAddr)
 
 		return true
 	}
@@ -739,26 +767,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 	return false
 }
 
-func (i *HostInfo) RecvErrorExceeded() bool {
-	if i.recvError.Add(1) >= maxRecvError {
-		return true
-	}
-	return true
-}
-
-func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
-	if len(networks) == 1 && len(unsafeNetworks) == 0 {
-		// Simple case, no CIDRTree needed
-		return
+// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up.
+func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) {
+	if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
+		if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) {
+			return // Simple case, no BART needed
+		}
 	}
 
-	i.networks = new(bart.Lite)
-	for _, network := range networks {
-		i.networks.Insert(network)
+	i.networks = new(bart.Table[NetworkType])
+	for _, network := range c.Networks() {
+		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
+		if myVpnNetworksTable.Contains(network.Addr()) {
+			i.networks.Insert(nprefix, NetworkTypeVPN)
+		} else {
+			i.networks.Insert(nprefix, NetworkTypeVPNPeer)
+		}
 	}
 
-	for _, network := range unsafeNetworks {
-		i.networks.Insert(network)
+	for _, network := range c.UnsafeNetworks() {
+		i.networks.Insert(network, NetworkTypeUnsafe)
 	}
 }
 

+ 6 - 5
inside.go

@@ -121,9 +121,10 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil)
 }
 
-// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
+// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established
+// it does not check if it is within our vpn networks!
 func (f *Interface) Handshake(vpnAddr netip.Addr) {
-	f.getOrHandshakeNoRouting(vpnAddr, nil)
+	f.handshakeManager.GetOrHandshake(vpnAddr, nil)
 }
 
 // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
@@ -139,7 +140,6 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu
 // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
 func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-
 	destinationAddr := fwPacket.RemoteAddr
 
 	hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
@@ -232,9 +232,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
 }
 
-// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr.
+// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr
 func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
+	hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 

+ 8 - 1
interface.go

@@ -234,6 +234,13 @@ func (f *Interface) activate() {
 		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
 
+	if f.routines > 1 {
+		if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
+			f.routines = 1
+			f.l.Warn("routines is not supported on this platform, falling back to a single routine")
+		}
+	}
+
 	metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
 
 	metrics.GetOrRegisterGauge("multiport.tx_ports", nil).Update(int64(f.multiPort.TxPorts))
@@ -286,7 +293,7 @@ func (f *Interface) listenOut(i int) {
 	nb := make([]byte, 12, 12)
 
 	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+		f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
 	})
 }
 

+ 133 - 115
lighthouse.go

@@ -24,6 +24,7 @@ import (
 )
 
 var ErrHostNotKnown = errors.New("host not known")
+var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr")
 
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time
@@ -56,7 +57,7 @@ type LightHouse struct {
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
 	staticList  atomic.Pointer[map[netip.Addr]struct{}]
-	lighthouses atomic.Pointer[map[netip.Addr]struct{}]
+	lighthouses atomic.Pointer[[]netip.Addr]
 
 	interval     atomic.Int64
 	updateCancel context.CancelFunc
@@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		queryChan:          make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
 		l:                  l,
 	}
-	lighthouses := make(map[netip.Addr]struct{})
+	lighthouses := make([]netip.Addr, 0)
 	h.lighthouses.Store(&lighthouses)
 	staticList := make(map[netip.Addr]struct{})
 	h.staticList.Store(&staticList)
@@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
 	return *lh.staticList.Load()
 }
 
-func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
+func (lh *LightHouse) GetLighthouses() []netip.Addr {
 	return *lh.lighthouses.Load()
 }
 
@@ -306,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 
 	if initial || c.HasChanged("lighthouse.hosts") {
-		lhMap := make(map[netip.Addr]struct{})
-		err := lh.parseLighthouses(c, lhMap)
+		lhList, err := lh.parseLighthouses(c)
 		if err != nil {
 			return err
 		}
 
-		lh.lighthouses.Store(&lhMap)
+		lh.lighthouses.Store(&lhList)
 		if !initial {
 			//NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic
 			lh.l.Info("lighthouse.hosts has changed")
@@ -346,36 +346,38 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) {
 	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
 	if lh.amLighthouse && len(lhs) != 0 {
 		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
+	out := make([]netip.Addr, len(lhs))
 
 	for i, host := range lhs {
 		addr, err := netip.ParseAddr(host)
 		if err != nil {
-			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
+			return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
 		}
 
 		if !lh.myVpnNetworksTable.Contains(addr) {
-			return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil)
+			lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}).
+				Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not")
 		}
-		lhMap[addr] = struct{}{}
+		out[i] = addr
 	}
 
-	if !lh.amLighthouse && len(lhMap) == 0 {
+	if !lh.amLighthouse && len(out) == 0 {
 		lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries")
 	}
 
 	staticList := lh.GetStaticHostList()
-	for lhAddr, _ := range lhMap {
-		if _, ok := staticList[lhAddr]; !ok {
-			return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr)
+	for i := range out {
+		if _, ok := staticList[out[i]]; !ok {
+			return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i])
 		}
 	}
 
-	return nil
+	return out, nil
 }
 
 func getStaticMapCadence(c *config.C) (time.Duration, error) {
@@ -430,7 +432,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc
 		}
 
 		if !lh.myVpnNetworksTable.Contains(vpnAddr) {
-			return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil)
+			lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}).
+				Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work")
 		}
 
 		vals, ok := v.([]any)
@@ -486,7 +489,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList {
 	lh.Lock()
 	defer lh.Unlock()
 	// Add an entry if we don't already have one
-	return lh.unlockedGetRemoteList(vpnAddrs)
+	return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip
 }
 
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
@@ -519,11 +522,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in
 }
 
 func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) {
-	// First we check the static mapping
-	// and do nothing if it is there
-	if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok {
-		return
+	// First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing.
+	staticList := lh.GetStaticHostList()
+	for _, addr := range allVpnAddrs {
+		if _, ok := staticList[addr]; ok {
+			return
+		}
 	}
+
+	// None of the VpnAddrs were present. Now we can do the deletes.
 	lh.Lock()
 	rm, ok := lh.addrMap[allVpnAddrs[0]]
 	if ok {
@@ -565,7 +572,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
 	am.unlockedSetHostnamesResults(hr)
 
 	for _, addrPort := range hr.GetAddrs() {
-		if !lh.shouldAdd(vpnAddr, addrPort.Addr()) {
+		if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) {
 			continue
 		}
 		switch {
@@ -627,23 +634,30 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool {
 	return len(calculatedV4) > 0 || len(calculatedV6) > 0
 }
 
-// unlockedGetRemoteList
-// assumes you have the lh lock
+// unlockedGetRemoteList assumes you have the lh lock
 func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList {
-	am, ok := lh.addrMap[allAddrs[0]]
-	if !ok {
-		am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) })
-		for _, addr := range allAddrs {
-			lh.addrMap[addr] = am
+	// before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet
+	for i, addr := range allAddrs {
+		am, ok := lh.addrMap[addr]
+		if ok {
+			if i != 0 {
+				lh.addrMap[allAddrs[0]] = am
+			}
+			return am
 		}
 	}
+
+	am := NewRemoteList(allAddrs, lh.shouldAdd)
+	for _, addr := range allAddrs {
+		lh.addrMap[addr] = am
+	}
 	return am
 }
 
-func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool {
-	allow := lh.GetRemoteAllowList().Allow(vpnAddr, to)
+func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool {
+	allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to)
 	if lh.l.Level >= logrus.TraceLevel {
-		lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow).
+		lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow).
 			Trace("remoteAllowList.Allow")
 	}
 	if !allow {
@@ -698,19 +712,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo
 }
 
 func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool {
-	if _, ok := lh.GetLighthouses()[vpnAddr]; ok {
-		return true
+	l := lh.GetLighthouses()
+	for i := range l {
+		if l[i] == vpnAddr {
+			return true
+		}
 	}
 	return false
 }
 
-// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake
-// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially
-func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool {
+func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool {
 	l := lh.GetLighthouses()
-	for _, a := range vpnAddr {
-		if _, ok := l[a]; ok {
-			return true
+	for i := range vpnAddrs {
+		for j := range l {
+			if l[j] == vpnAddrs[i] {
+				return true
+			}
 		}
 	}
 	return false
@@ -752,7 +769,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
 	queried := 0
 	lighthouses := lh.GetLighthouses()
 
-	for lhVpnAddr := range lighthouses {
+	for _, lhVpnAddr := range lighthouses {
 		hi := lh.ifce.GetHostInfo(lhVpnAddr)
 		if hi != nil {
 			v = hi.ConnectionState.myCert.Version()
@@ -870,7 +887,7 @@ func (lh *LightHouse) SendUpdate() {
 	updated := 0
 	lighthouses := lh.GetLighthouses()
 
-	for lhVpnAddr := range lighthouses {
+	for _, lhVpnAddr := range lighthouses {
 		var v cert.Version
 		hi := lh.ifce.GetHostInfo(lhVpnAddr)
 		if hi != nil {
@@ -928,7 +945,6 @@ func (lh *LightHouse) SendUpdate() {
 						V4AddrPorts:   v4,
 						V6AddrPorts:   v6,
 						RelayVpnAddrs: relays,
-						VpnAddr:       netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()),
 					},
 				}
 
@@ -1048,19 +1064,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 		return
 	}
 
-	useVersion := cert.Version1
-	var queryVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		queryVpnAddr = netip.AddrFrom4(b)
-		useVersion = 1
-	} else if n.Details.VpnAddr != nil {
-		queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		useVersion = 2
-	} else {
+	queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
+				Debugln("Dropping malformed HostQuery")
+		}
+		return
+	}
+	if useVersion == cert.Version1 && queryVpnAddr.Is6() {
+		// this case really shouldn't be possible to represent, but reject it anyway.
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
+			lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
+				Debugln("invalid vpn addr for v1 handleHostQuery")
 		}
 		return
 	}
@@ -1069,9 +1085,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		if useVersion == cert.Version1 {
-			if !queryVpnAddr.Is4() {
-				return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
-			}
 			b := queryVpnAddr.As4()
 			n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
 		} else {
@@ -1116,8 +1129,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd
 			if ok {
 				whereToPunch = newDest
 			} else {
-				//TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee
-				//choosing to do nothing for now, but maybe we return an error?
+				if lhh.l.Level >= logrus.DebugLevel {
+					lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common")
+				}
 			}
 		}
 
@@ -1176,19 +1190,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul
 				if !r.Is4() {
 					continue
 				}
-
 				b = r.As4()
 				n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:]))
 			}
-
 		} else if v == cert.Version2 {
 			for _, r := range c.relay.relay {
 				n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r))
 			}
-
 		} else {
-			//TODO: CERT-V2 don't panic
-			panic("unsupported version")
+			if lhh.l.Level >= logrus.DebugLevel {
+				lhh.l.WithField("version", v).Debug("unsupported protocol version")
+			}
 		}
 	}
 }
@@ -1198,18 +1210,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [
 		return
 	}
 
-	lhh.lh.Lock()
-
-	var certVpnAddr netip.Addr
-	if n.Details.OldVpnAddr != 0 {
-		b := [4]byte{}
-		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-		certVpnAddr = netip.AddrFrom4(b)
-	} else if n.Details.VpnAddr != nil {
-		certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+	certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply")
+		}
+		return
 	}
 	relays := n.Details.GetRelays()
 
+	lhh.lh.Lock()
 	am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr})
 	am.Lock()
 	lhh.lh.Unlock()
@@ -1234,27 +1244,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 		return
 	}
 
+	// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
 	var detailsVpnAddr netip.Addr
-	useVersion := cert.Version1
-	if n.Details.OldVpnAddr != 0 {
+	var useVersion cert.Version
+	if n.Details.OldVpnAddr != 0 { //v1 always sets this field
 		b := [4]byte{}
 		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
 		detailsVpnAddr = netip.AddrFrom4(b)
 		useVersion = cert.Version1
-	} else if n.Details.VpnAddr != nil {
+	} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
 		detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
 		useVersion = cert.Version2
 	} else {
-		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification")
-		}
-		return
+		detailsVpnAddr = netip.Addr{}
+		useVersion = cert.Version2
 	}
 
-	//TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
-	//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
-	//Simple check that the host sent this not someone else
-	if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
+	//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
+	if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
 		}
@@ -1268,24 +1275,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 	am.Lock()
 	lhh.lh.Unlock()
 
-	am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
+	am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
 	am.unlockedSetRelay(fromVpnAddrs[0], relays)
 	am.Unlock()
 
 	n = lhh.resetMeta()
 	n.Type = NebulaMeta_HostUpdateNotificationAck
-
-	if useVersion == cert.Version1 {
+	switch useVersion {
+	case cert.Version1:
 		if !fromVpnAddrs[0].Is4() {
 			lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
 			return
 		}
 		vpnAddrB := fromVpnAddrs[0].As4()
 		n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
-	} else if useVersion == cert.Version2 {
-		n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
-	} else {
+	case cert.Version2:
+		// do nothing, we want to send a blank message
+	default:
 		lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
 		return
 	}
@@ -1303,13 +1310,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) {
 	//It's possible the lighthouse is communicating with us using a non primary vpn addr,
 	//which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs.
-	//maybe one day we'll have a better idea, if it matters.
 	if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) {
 		return
 	}
 
+	detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion()
+	if err != nil {
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification")
+		}
+		return
+	}
+
 	empty := []byte{0}
-	punch := func(vpnPeer netip.AddrPort) {
+	punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
 		if !vpnPeer.IsValid() {
 			return
 		}
@@ -1321,48 +1335,38 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
 		}()
 
 		if lhh.l.Level >= logrus.DebugLevel {
-			var logVpnAddr netip.Addr
-			if n.Details.OldVpnAddr != 0 {
-				b := [4]byte{}
-				binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-				logVpnAddr = netip.AddrFrom4(b)
-			} else if n.Details.VpnAddr != nil {
-				logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-			}
 			lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr)
 		}
 	}
 
+	remoteAllowList := lhh.lh.GetRemoteAllowList()
 	for _, a := range n.Details.V4AddrPorts {
-		punch(protoV4AddrPortToNetAddrPort(a))
+		b := protoV4AddrPortToNetAddrPort(a)
+		if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
+			punch(b, detailsVpnAddr)
+		}
 	}
 
 	for _, a := range n.Details.V6AddrPorts {
-		punch(protoV6AddrPortToNetAddrPort(a))
+		b := protoV6AddrPortToNetAddrPort(a)
+		if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
+			punch(b, detailsVpnAddr)
+		}
 	}
 
 	// This sends a nebula test packet to the host trying to contact us. In the case
 	// of a double nat or other difficult scenario, this may help establish
 	// a tunnel.
 	if lhh.lh.punchy.GetRespond() {
-		var queryVpnAddr netip.Addr
-		if n.Details.OldVpnAddr != 0 {
-			b := [4]byte{}
-			binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
-			queryVpnAddr = netip.AddrFrom4(b)
-		} else if n.Details.VpnAddr != nil {
-			queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
-		}
-
 		go func() {
 			time.Sleep(lhh.lh.punchy.GetRespondDelay())
 			if lhh.l.Level >= logrus.DebugLevel {
-				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr)
+				lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr)
 			}
 			//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
 			// for each punchBack packet. We should move this into a timerwheel or a single goroutine
 			// managed by a channel.
-			w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 		}()
 	}
 }
@@ -1441,3 +1445,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
 	}
 	return netip.Addr{}, false
 }
+
+func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) {
+	if d.OldVpnAddr != 0 {
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], d.OldVpnAddr)
+		detailsVpnAddr := netip.AddrFrom4(b)
+		return detailsVpnAddr, cert.Version1, nil
+	} else if d.VpnAddr != nil {
+		detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr)
+		return detailsVpnAddr, cert.Version2, nil
+	} else {
+		return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
+	}
+}

+ 121 - 1
lighthouse_test.go

@@ -14,7 +14,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 
 func TestOldIPv4Only(t *testing.T) {
@@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) {
 	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
 	assert.False(t, ok)
 }
+
+func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) {
+	l := test.NewLogger()
+
+	myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
+
+	testSameHostNotStatic := netip.MustParseAddr("10.128.0.41")
+	testStaticHost := netip.MustParseAddr("10.128.0.42")
+	//myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+	c := config.NewC(l)
+	lh1 := "10.128.0.2"
+	c.Settings["lighthouse"] = map[string]any{
+		"hosts":    []any{lh1},
+		"interval": "1s",
+	}
+
+	c.Settings["listen"] = map[string]any{"port": 4242}
+	c.Settings["static_host_map"] = map[string]any{
+		lh1:           []any{"1.1.1.1:4242"},
+		"10.128.0.42": []any{"1.2.3.4:4242"},
+	}
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
+	lh.ifce = &mockEncWriter{}
+
+	//test that we actually have the static entry:
+	out := lh.Query(testStaticHost)
+	assert.NotNil(t, out)
+	assert.Equal(t, out.vpnAddrs[0], testStaticHost)
+	out.Rebuild([]netip.Prefix{}) //why tho
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+
+	//bolt on a lower numbered primary IP
+	am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost})
+	am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost}
+	lh.addrMap[testSameHostNotStatic] = am
+	out.Rebuild([]netip.Prefix{}) //???
+
+	//test that we actually have the static entry:
+	out = lh.Query(testStaticHost)
+	assert.NotNil(t, out)
+	assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
+	assert.Equal(t, out.vpnAddrs[1], testStaticHost)
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+
+	//test that we actually have the static entry for BOTH:
+	out2 := lh.Query(testSameHostNotStatic)
+	assert.Same(t, out2, out)
+
+	//now do the delete
+	lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost})
+	//verify
+	out = lh.Query(testSameHostNotStatic)
+	assert.NotNil(t, out)
+	if out == nil {
+		t.Fatal("expected non-nil query for the static host")
+	}
+	assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic)
+	assert.Equal(t, out.vpnAddrs[1], testStaticHost)
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+}
+
+func TestLighthouse_DeletesWork(t *testing.T) {
+	l := test.NewLogger()
+
+	myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242")
+	testHost := netip.MustParseAddr("10.128.0.42")
+
+	c := config.NewC(l)
+	lh1 := "10.128.0.2"
+	c.Settings["lighthouse"] = map[string]any{
+		"hosts":    []any{lh1},
+		"interval": "1s",
+	}
+
+	c.Settings["listen"] = map[string]any{"port": 4242}
+	c.Settings["static_host_map"] = map[string]any{
+		lh1: []any{"1.1.1.1:4242"},
+	}
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Lite)
+	nt.Insert(myVpnNet)
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
+	lh.ifce = &mockEncWriter{}
+
+	//insert the host
+	am := lh.unlockedGetRemoteList([]netip.Addr{testHost})
+	am.vpnAddrs = []netip.Addr{testHost}
+	am.addrs = []netip.AddrPort{myUdpAddr2}
+	lh.addrMap[testHost] = am
+	am.Rebuild([]netip.Prefix{}) //???
+
+	//test that we actually have the entry:
+	out := lh.Query(testHost)
+	assert.NotNil(t, out)
+	assert.Equal(t, out.vpnAddrs[0], testHost)
+	out.Rebuild([]netip.Prefix{}) //why tho
+	assert.Equal(t, out.addrs[0], myUdpAddr2)
+
+	//now do the delete
+	lh.DeleteVpnAddrs([]netip.Addr{testHost})
+	//verify
+	out = lh.Query(testHost)
+	assert.Nil(t, out)
+}

+ 24 - 2
main.go

@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"net"
 	"net/netip"
+	"runtime/debug"
+	"strings"
 	"time"
 
 	"github.com/sirupsen/logrus"
@@ -13,7 +15,7 @@ import (
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
-	"gopkg.in/yaml.v3"
+	"go.yaml.in/yaml/v3"
 )
 
 type m = map[string]any
@@ -27,6 +29,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}()
 
+	if buildVersion == "" {
+		buildVersion = moduleVersion()
+	}
+
 	l := logger
 	l.Formatter = &logrus.TextFormatter{
 		FullTimestamp: true,
@@ -75,7 +81,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if c.GetBool("sshd.enabled", false) {
 		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
-			return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
+			l.WithError(err).Warn("Failed to configure sshd, ssh debugging will not be available")
+			sshStart = nil
 		}
 	}
 
@@ -328,3 +335,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		connManager.Start,
 	}, nil
 }
+
+func moduleVersion() string {
+	info, ok := debug.ReadBuildInfo()
+	if !ok {
+		return ""
+	}
+
+	for _, dep := range info.Deps {
+		if dep.Path == "github.com/slackhq/nebula" {
+			return strings.TrimPrefix(dep.Version, "v")
+		}
+	}
+
+	return ""
+}

+ 51 - 47
outside.go

@@ -19,21 +19,21 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
-			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
+			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
 		}
 		return
 	}
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
-	if ip.IsValid() {
-		if f.myVpnNetworksTable.Contains(ip.Addr()) {
+	if !via.IsRelayed {
+		if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
 			if f.l.Level >= logrus.DebugLevel {
-				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
+				f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
 			}
 			return
 		}
@@ -54,8 +54,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	switch h.Type {
 	case header.Message:
-		// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
@@ -79,7 +78,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			// Successfully validated the thing. Get rid of the Relay header.
 			signedPayload = signedPayload[header.Len:]
 			// Pull the Roaming parts up here, and return in all call paths.
-			f.handleHostRoaming(hostinfo, ip)
+			f.handleHostRoaming(hostinfo, via)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			f.connectionManager.In(hostinfo)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -96,7 +95,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			case TerminalType:
 				// If I am the target of this relay, process the unwrapped packet
 				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+				via = ViaSender{
+					UdpAddr:   via.UdpAddr,
+					relayHI:   hostinfo,
+					remoteIdx: relay.RemoteIndex,
+					relay:     relay,
+					IsRelayed: true,
+				}
+				f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
@@ -126,31 +132,32 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	case header.LightHouse:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+			hostinfo.logger(f.l).WithError(err).WithField("from", via).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
 			return
 		}
 
-		lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
+		//TODO: assert via is not relayed
+		lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
 
 		// Fallthrough to the bottom to record incoming traffic
 
 	case header.Test:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+			hostinfo.logger(f.l).WithError(err).WithField("from", via).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
 			return
@@ -159,7 +166,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 		if h.Subtype == header.TestRequest {
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
-			f.handleHostRoaming(hostinfo, ip)
+			f.handleHostRoaming(hostinfo, via)
 			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
 		}
 
@@ -170,34 +177,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handshakeManager.HandleIncoming(ip, via, packet, h)
+		f.handshakeManager.HandleIncoming(via, packet, h)
 		return
 
 	case header.RecvError:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handleRecvError(ip, h)
+		f.handleRecvError(via.UdpAddr, h)
 		return
 
 	case header.CloseTunnel:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", ip).
+		hostinfo.logger(f.l).WithField("from", via).
 			Info("Close tunnel received, tearing down.")
 
 		f.closeTunnel(hostinfo)
 		return
 
 	case header.Control:
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+			hostinfo.logger(f.l).WithError(err).WithField("from", via).
 				WithField("packet", packet).
 				Error("Failed to decrypt Control packet")
 			return
@@ -207,11 +214,11 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
+		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
 		return
 	}
 
-	f.handleHostRoaming(hostinfo, ip)
+	f.handleHostRoaming(hostinfo, via)
 
 	f.connectionManager.In(hostinfo)
 }
@@ -230,50 +237,51 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
-	if udpAddr.IsValid() && hostinfo.remote != udpAddr {
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
+	if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
 		if hostinfo.multiportRx {
 			// If the remote is sending with multiport, we aren't roaming unless
 			// the IP has changed
-			if hostinfo.remote.Addr().Compare(udpAddr.Addr()) == 0 {
+			if hostinfo.remote.Addr().Compare(via.UdpAddr.Addr()) == 0 {
 				return
 			}
 			// Keep the port from the original hostinfo, because the remote is transmitting from multiport ports
-			udpAddr = netip.AddrPortFrom(udpAddr.Addr(), hostinfo.remote.Port())
+			via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
 		}
-
-		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
 
-		if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+		if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(udpAddr)
+		hostinfo.SetRemote(via.UdpAddr)
 	}
 
 }
 
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
-	// If connectionstate exists and the replay protector allows, process packet
-	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
-	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
-		if addr.IsValid() {
-			f.maybeSendRecvError(addr, h.RemoteIndex)
-			return false
-		} else {
-			return false
+// handleEncrypted returns true if a packet should be processed, false otherwise
+func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
+	// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
+	if ci == nil {
+		if !via.IsRelayed {
+			f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
 		}
+		return false
+	}
+	// If the window check fails, refuse to process the packet, but don't send a recv error
+	if !ci.window.Check(f.l, h.MessageCounter) {
+		return false
 	}
 
 	return true
@@ -547,10 +555,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 		return
 	}
 
-	if !hostinfo.RecvErrorExceeded() {
-		return
-	}
-
 	if hostinfo.remote.IsValid() && hostinfo.remote != addr {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return

+ 1 - 0
overlay/device.go

@@ -13,5 +13,6 @@ type Device interface {
 	Networks() []netip.Prefix
 	Name() string
 	RoutesFor(netip.Addr) routing.Gateways
+	SupportsMultiqueue() bool
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }

+ 50 - 0
overlay/tun.go

@@ -1,6 +1,8 @@
 package overlay
 
 import (
+	"fmt"
+	"net"
 	"net/netip"
 
 	"github.com/sirupsen/logrus"
@@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route {
 
 	return removed
 }
+
+func prefixToMask(prefix netip.Prefix) netip.Addr {
+	pLen := 128
+	if prefix.Addr().Is4() {
+		pLen = 32
+	}
+
+	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+	return addr
+}
+
+func flipBytes(b []byte) []byte {
+	for i := 0; i < len(b); i++ {
+		b[i] ^= 0xFF
+	}
+	return b
+}
+func orBytes(a []byte, b []byte) []byte {
+	ret := make([]byte, len(a))
+	for i := 0; i < len(a); i++ {
+		ret[i] = a[i] | b[i]
+	}
+	return ret
+}
+
+func getBroadcast(cidr netip.Prefix) netip.Addr {
+	broadcast, _ := netip.AddrFromSlice(
+		orBytes(
+			cidr.Addr().AsSlice(),
+			flipBytes(prefixToMask(cidr).AsSlice()),
+		),
+	)
+	return broadcast
+}
+
+func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) {
+	for _, gateway := range gateways {
+		if dest.Addr().Is4() && gateway.Addr().Is4() {
+			return gateway, nil
+		}
+
+		if dest.Addr().Is6() && gateway.Addr().Is6() {
+			return gateway, nil
+		}
+	}
+
+	return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
+}

+ 4 - 0
overlay/tun_android.go

@@ -95,6 +95,10 @@ func (t *tun) Name() string {
 	return "android"
 }
 
+func (t *tun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
 }

+ 4 - 12
overlay/tun_darwin.go

@@ -7,7 +7,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/netip"
 	"os"
 	"sync/atomic"
@@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error {
 			Vltime: 0xffffffff,
 			Pltime: 0xffffffff,
 		},
-		//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
 		Flags: _IN6_IFF_NODAD,
 	}
 
@@ -551,16 +549,10 @@ func (t *tun) Name() string {
 	return t.Device
 }
 
-func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
+func (t *tun) SupportsMultiqueue() bool {
+	return false
 }
 
-func prefixToMask(prefix netip.Prefix) netip.Addr {
-	pLen := 128
-	if prefix.Addr().Is4() {
-		pLen = 32
-	}
-
-	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
-	return addr
+func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }

+ 4 - 0
overlay/tun_disabled.go

@@ -105,6 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
 	return len(b), nil
 }
 
+func (t *disabledTun) SupportsMultiqueue() bool {
+	return true
+}
+
 func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return t, nil
 }

+ 383 - 68
overlay/tun_freebsd.go

@@ -10,11 +10,9 @@ import (
 	"io"
 	"io/fs"
 	"net/netip"
-	"os"
-	"os/exec"
-	"strconv"
 	"sync/atomic"
 	"syscall"
+	"time"
 	"unsafe"
 
 	"github.com/gaissmai/bart"
@@ -22,12 +20,18 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
+	"golang.org/x/sys/unix"
 )
 
 const (
 	// FIODGNAME is defined in sys/sys/filio.h on FreeBSD
 	// For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678)
-	FIODGNAME = 0x80106678
+	FIODGNAME        = 0x80106678
+	TUNSIFMODE       = 0x8004745e
+	TUNSIFHEAD       = 0x80047460
+	OSIOCAIFADDR_IN6 = 0x8088691b
+	IN6_IFF_NODAD    = 0x0020
 )
 
 type fiodgnameArg struct {
@@ -37,43 +41,159 @@ type fiodgnameArg struct {
 }
 
 type ifreqRename struct {
-	Name [16]byte
+	Name [unix.IFNAMSIZ]byte
 	Data uintptr
 }
 
 type ifreqDestroy struct {
-	Name [16]byte
+	Name [unix.IFNAMSIZ]byte
 	pad  [16]byte
 }
 
+type ifReq struct {
+	Name  [unix.IFNAMSIZ]byte
+	Flags uint16
+}
+
+type ifreqMTU struct {
+	Name [unix.IFNAMSIZ]byte
+	MTU  int32
+}
+
+type addrLifetime struct {
+	Expire    uint64
+	Preferred uint64
+	Vltime    uint32
+	Pltime    uint32
+}
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+	VHid     uint32
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+	VHid       uint32
+}
+
 type tun struct {
 	Device      string
 	vpnNetworks []netip.Prefix
 	MTU         int
 	Routes      atomic.Pointer[[]Route]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
+	linkAddr    *netroute.LinkAddr
 	l           *logrus.Logger
+	devFd       int
+}
+
+func (t *tun) Read(to []byte) (int, error) {
+	// use readv() to read from the tunnel device, to eliminate the need for copying the buffer
+	if t.devFd < 0 {
+		return -1, syscall.EINVAL
+	}
+
+	// first 4 bytes is protocol family, in network byte order
+	head := make([]byte, 4)
 
-	io.ReadWriteCloser
+	iovecs := []syscall.Iovec{
+		{&head[0], 4},
+		{&to[0], uint64(len(to))},
+	}
+
+	n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
+
+	var err error
+	if errno != 0 {
+		err = syscall.Errno(errno)
+	} else {
+		err = nil
+	}
+	// fix bytes read number to exclude header
+	bytesRead := int(n)
+	if bytesRead < 0 {
+		return bytesRead, err
+	} else if bytesRead < 4 {
+		return 0, err
+	} else {
+		return bytesRead - 4, err
+	}
 }
 
-func (t *tun) Close() error {
-	if t.ReadWriteCloser != nil {
-		if err := t.ReadWriteCloser.Close(); err != nil {
-			return err
-		}
+// Write is only valid for single threaded use
+func (t *tun) Write(from []byte) (int, error) {
+	// use writev() to write to the tunnel device, to eliminate the need for copying the buffer
+	if t.devFd < 0 {
+		return -1, syscall.EINVAL
+	}
 
-		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+	if len(from) <= 1 {
+		return 0, syscall.EIO
+	}
+	ipVer := from[0] >> 4
+	var head []byte
+	// first 4 bytes is protocol family, in network byte order
+	if ipVer == 4 {
+		head = []byte{0, 0, 0, syscall.AF_INET}
+	} else if ipVer == 6 {
+		head = []byte{0, 0, 0, syscall.AF_INET6}
+	} else {
+		return 0, fmt.Errorf("unable to determine IP version from packet")
+	}
+	iovecs := []syscall.Iovec{
+		{&head[0], 4},
+		{&from[0], uint64(len(from))},
+	}
+
+	n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
+
+	var err error
+	if errno != 0 {
+		err = syscall.Errno(errno)
+	} else {
+		err = nil
+	}
+
+	return int(n) - 4, err
+}
+
+func (t *tun) Close() error {
+	if t.devFd >= 0 {
+		err := syscall.Close(t.devFd)
 		if err != nil {
-			return err
+			t.l.WithError(err).Error("Error closing device")
 		}
-		defer syscall.Close(s)
-
-		ifreq := ifreqDestroy{Name: t.deviceBytes()}
+		t.devFd = -1
+
+		c := make(chan struct{})
+		go func() {
+			// destroying the interface can block if a read() is still pending. Do this asynchronously.
+			defer close(c)
+			s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+			if err == nil {
+				defer syscall.Close(s)
+				ifreq := ifreqDestroy{Name: t.deviceBytes()}
+				err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
+			}
+			if err != nil {
+				t.l.WithError(err).Error("Error destroying tunnel")
+			}
+		}()
 
-		// Destroy the interface
-		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
-		return err
+		// wait up to 1 second so we start blocking at the ioctl
+		select {
+		case <-c:
+		case <-time.After(1 * time.Second):
+		}
 	}
 
 	return nil
@@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun,
 
 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
-	var file *os.File
+	var fd int
 	var err error
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName != "" {
-		file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+		fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0)
 	}
 	if errors.Is(err, fs.ErrNotExist) || deviceName == "" {
 		// If the device doesn't already exist, request a new one and rename it
-		file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0)
+		fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0)
 	}
 	if err != nil {
 		return nil, err
 	}
 
-	rawConn, err := file.SyscallConn()
-	if err != nil {
-		return nil, fmt.Errorf("SyscallConn: %v", err)
+	// Read the name of the interface
+	var name [16]byte
+	arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
+	ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg)))
+
+	if ctrlErr == nil {
+		// set broadcast mode and multicast
+		ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST)
+		ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode)))
+	}
+
+	if ctrlErr == nil {
+		// turn on link-layer mode, to support ipv6
+		ifhead := uint32(1)
+		ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&ifhead)))
 	}
 
-	var name [16]byte
-	var ctrlErr error
-	rawConn.Control(func(fd uintptr) {
-		// Read the name of the interface
-		arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)}
-		ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg)))
-	})
 	if ctrlErr != nil {
 		return nil, err
 	}
@@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 
 	// If the name doesn't match the desired interface name, rename it now
 	if ifName != deviceName {
-		s, err := syscall.Socket(
-			syscall.AF_INET,
-			syscall.SOCK_DGRAM,
-			syscall.IPPROTO_IP,
-		)
+		s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
 		if err != nil {
 			return nil, err
 		}
@@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 	}
 
 	t := &tun{
-		ReadWriteCloser: file,
-		Device:          deviceName,
-		vpnNetworks:     vpnNetworks,
-		MTU:             c.GetInt("tun.mtu", DefaultMTU),
-		l:               l,
+		Device:      deviceName,
+		vpnNetworks: vpnNetworks,
+		MTU:         c.GetInt("tun.mtu", DefaultMTU),
+		l:           l,
+		devFd:       fd,
 	}
 
 	err = t.reload(c, true)
@@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 }
 
 func (t *tun) addIp(cidr netip.Prefix) error {
-	var err error
-	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	if cidr.Addr().Is4() {
+		ifr := ifreqAlias4{
+			Name: t.deviceBytes(),
+			Addr: unix.RawSockaddrInet4{
+				Len:    unix.SizeofSockaddrInet4,
+				Family: unix.AF_INET,
+				Addr:   cidr.Addr().As4(),
+			},
+			DstAddr: unix.RawSockaddrInet4{
+				Len:    unix.SizeofSockaddrInet4,
+				Family: unix.AF_INET,
+				Addr:   getBroadcast(cidr).As4(),
+			},
+			MaskAddr: unix.RawSockaddrInet4{
+				Len:    unix.SizeofSockaddrInet4,
+				Family: unix.AF_INET,
+				Addr:   prefixToMask(cidr).As4(),
+			},
+			VHid: 0,
+		}
+		s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+		// Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR
+		if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
+		}
+		return nil
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'route add': %s", err)
-	}
+	if cidr.Addr().Is6() {
+		ifr := ifreqAlias6{
+			Name: t.deviceBytes(),
+			Addr: unix.RawSockaddrInet6{
+				Len:    unix.SizeofSockaddrInet6,
+				Family: unix.AF_INET6,
+				Addr:   cidr.Addr().As16(),
+			},
+			PrefixMask: unix.RawSockaddrInet6{
+				Len:    unix.SizeofSockaddrInet6,
+				Family: unix.AF_INET6,
+				Addr:   prefixToMask(cidr).As16(),
+			},
+			Lifetime: addrLifetime{
+				Expire:    0,
+				Preferred: 0,
+				Vltime:    0xffffffff,
+				Pltime:    0xffffffff,
+			},
+			Flags: IN6_IFF_NODAD,
+		}
+		s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
 
-	cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+		if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
+		}
+		return nil
 	}
 
-	// Unsafe path routes
-	return t.addRoutes(false)
+	return fmt.Errorf("unknown address type %v", cidr)
 }
 
 func (t *tun) Activate() error {
+	// Setup our default MTU
+	err := t.setMTU()
+	if err != nil {
+		return err
+	}
+
+	linkAddr, err := getLinkAddr(t.Device)
+	if err != nil {
+		return err
+	}
+	if linkAddr == nil {
+		return fmt.Errorf("unable to discover link_addr for tun interface")
+	}
+	t.linkAddr = linkAddr
+
 	for i := range t.vpnNetworks {
 		err := t.addIp(t.vpnNetworks[i])
 		if err != nil {
 			return err
 		}
 	}
-	return nil
+
+	return t.addRoutes(false)
+}
+
+func (t *tun) setMTU() error {
+	// Set the MTU on the device
+	s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+	if err != nil {
+		return err
+	}
+	defer syscall.Close(s)
+
+	ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)}
+	err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm)))
+	return err
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
@@ -256,6 +450,10 @@ func (t *tun) Name() string {
 	return t.Device
 }
 
+func (t *tun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
 }
@@ -268,15 +466,16 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device)
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
+		err := addRoute(r.Cidr, t.linkAddr)
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {
 				return retErr
 			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
 		}
 	}
 
@@ -289,9 +488,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device)
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
+		err := delRoute(r.Cidr, t.linkAddr)
+		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 			t.l.WithField("route", r).Info("Removed route")
@@ -306,3 +504,120 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	}
 	return
 }
+
+func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_ADD,
+		Flags:   unix.RTF_UP,
+		Seq:     1,
+	}
+
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		if errors.Is(err, unix.EEXIST) {
+			// Try to do a change
+			route.Type = unix.RTM_CHANGE
+			data, err = route.Marshal()
+			if err != nil {
+				return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
+			}
+			_, err = unix.Write(sock, data[:])
+			fmt.Println("DOING CHANGE")
+			return err
+		}
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}
+
+func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_DELETE,
+		Seq:     1,
+	}
+
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}
+
+// getLinkAddr Gets the link address for the interface of the given name
+func getLinkAddr(name string) (*netroute.LinkAddr, error) {
+	rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
+	if err != nil {
+		return nil, err
+	}
+	msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, m := range msgs {
+		switch m := m.(type) {
+		case *netroute.InterfaceMessage:
+			if m.Name == name {
+				sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr)
+				if ok {
+					return sa, nil
+				}
+			}
+		}
+	}
+
+	return nil, nil
+}

+ 4 - 0
overlay/tun_ios.go

@@ -151,6 +151,10 @@ func (t *tun) Name() string {
 	return "iOS"
 }
 
+func (t *tun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
 }

+ 52 - 28
overlay/tun_linux.go

@@ -216,6 +216,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
+func (t *tun) SupportsMultiqueue() bool {
+	return true
+}
+
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
@@ -293,7 +297,6 @@ func (t *tun) addIPs(link netlink.Link) error {
 
 	//add all new addresses
 	for i := range newAddrs {
-		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
 		//AddrReplace still adds new IPs, but if their properties change it will change them as well
 		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
 			return err
@@ -361,6 +364,11 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 
+	const modeNone = 1
+	if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil {
+		t.l.WithError(err).Warn("Failed to disable link local address generation")
+	}
+
 	if err = t.addIPs(link); err != nil {
 		return err
 	}
@@ -578,48 +586,42 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
 }
 
 func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
-
 	var gateways routing.Gateways
 
 	link, err := netlink.LinkByName(t.Device)
 	if err != nil {
-		t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
+		t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
 		return gateways
 	}
 
 	// If this route is relevant to our interface and there is a gateway then add it
-	if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
-		gwAddr, ok := netip.AddrFromSlice(r.Gw)
-		if !ok {
-			t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
-		} else {
-			gwAddr = gwAddr.Unmap()
-
-			if !t.isGatewayInVpnNetworks(gwAddr) {
-				// Gateway isn't in our overlay network, ignore
-				t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
-			} else {
+	if r.LinkIndex == link.Attrs().Index {
+		gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
+		if ok {
+			if t.isGatewayInVpnNetworks(gwAddr) {
 				gateways = append(gateways, routing.NewGateway(gwAddr, 1))
+			} else {
+				// Gateway isn't in our overlay network, ignore
+				t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
 			}
+		} else {
+			t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
 		}
 	}
 
 	for _, p := range r.MultiPath {
 		// If this route is relevant to our interface and there is a gateway then add it
-		if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
-			gwAddr, ok := netip.AddrFromSlice(p.Gw)
-			if !ok {
-				t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
-			} else {
-				gwAddr = gwAddr.Unmap()
-
-				if !t.isGatewayInVpnNetworks(gwAddr) {
-					// Gateway isn't in our overlay network, ignore
-					t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
-				} else {
-					// p.Hops+1 = weight of the route
+		if p.LinkIndex == link.Attrs().Index {
+			gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
+			if ok {
+				if t.isGatewayInVpnNetworks(gwAddr) {
 					gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
+				} else {
+					// Gateway isn't in our overlay network, ignore
+					t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
 				}
+			} else {
+				t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
 			}
 		}
 	}
@@ -628,16 +630,38 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
 	return gateways
 }
 
-func (t *tun) updateRoutes(r netlink.RouteUpdate) {
+func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
+	// Try to use the old RTA_GATEWAY first
+	gwAddr, ok := netip.AddrFromSlice(gw)
+	if !ok {
+		// Fallback to the new RTA_VIA
+		rVia, ok := via.(*netlink.Via)
+		if ok {
+			gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
+		}
+	}
 
-	gateways := t.getGatewaysFromRoute(&r.Route)
+	if gwAddr.IsValid() {
+		gwAddr = gwAddr.Unmap()
+		return gwAddr, true
+	}
+
+	return netip.Addr{}, false
+}
 
+func (t *tun) updateRoutes(r netlink.RouteUpdate) {
+	gateways := t.getGatewaysFromRoute(&r.Route)
 	if len(gateways) == 0 {
 		// No gateways relevant to our network, no routing changes required.
 		t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
 		return
 	}
 
+	if r.Dst == nil {
+		t.l.WithField("route", r).Debug("Ignoring route update, no destination address")
+		return
+	}
+
 	dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
 	if !ok {
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")

+ 372 - 63
overlay/tun_netbsd.go

@@ -4,13 +4,12 @@
 package overlay
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"net/netip"
 	"os"
-	"os/exec"
 	"regexp"
-	"strconv"
 	"sync/atomic"
 	"syscall"
 	"unsafe"
@@ -20,11 +19,42 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
+	"golang.org/x/sys/unix"
 )
 
-type ifreqDestroy struct {
-	Name [16]byte
-	pad  [16]byte
+const (
+	SIOCAIFADDR_IN6 = 0x8080696b
+	TUNSIFHEAD      = 0x80047442
+	TUNSIFMODE      = 0x80047458
+)
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+type ifreq struct {
+	Name [unix.IFNAMSIZ]byte
+	data int
+}
+
+type addrLifetime struct {
+	Expire    uint64
+	Preferred uint64
+	Vltime    uint32
+	Pltime    uint32
 }
 
 type tun struct {
@@ -34,40 +64,18 @@ type tun struct {
 	Routes      atomic.Pointer[[]Route]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
-
-	io.ReadWriteCloser
+	f           *os.File
+	fd          int
 }
 
-func (t *tun) Close() error {
-	if t.ReadWriteCloser != nil {
-		if err := t.ReadWriteCloser.Close(); err != nil {
-			return err
-		}
-
-		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
-		if err != nil {
-			return err
-		}
-		defer syscall.Close(s)
-
-		ifreq := ifreqDestroy{Name: t.deviceBytes()}
-
-		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq)))
-
-		return err
-	}
-	return nil
-}
+var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 
-var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-
 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
-	var file *os.File
 	var err error
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
@@ -77,17 +85,23 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
 	}
 
-	file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+	fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
 	}
 
+	err = unix.SetNonblock(fd, true)
+	if err != nil {
+		l.WithError(err).Warn("Failed to set the tun device as nonblocking")
+	}
+
 	t := &tun{
-		ReadWriteCloser: file,
-		Device:          deviceName,
-		vpnNetworks:     vpnNetworks,
-		MTU:             c.GetInt("tun.mtu", DefaultMTU),
-		l:               l,
+		f:           os.NewFile(uintptr(fd), ""),
+		fd:          fd,
+		Device:      deviceName,
+		vpnNetworks: vpnNetworks,
+		MTU:         c.GetInt("tun.mtu", DefaultMTU),
+		l:           l,
 	}
 
 	err = t.reload(c, true)
@@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 	return t, nil
 }
 
-func (t *tun) addIp(cidr netip.Prefix) error {
-	var err error
+func (t *tun) Close() error {
+	if t.f != nil {
+		if err := t.f.Close(); err != nil {
+			return fmt.Errorf("error closing tun file: %w", err)
+		}
+
+		// t.f.Close should have handled it for us but let's be extra sure
+		_ = unix.Close(t.fd)
 
-	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+		s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		ifr := ifreq{Name: t.deviceBytes()}
+		err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr)))
+		return err
 	}
+	return nil
+}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'route add': %s", err)
+func (t *tun) Read(to []byte) (int, error) {
+	rc, err := t.f.SyscallConn()
+	if err != nil {
+		return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err)
 	}
 
-	cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
+	var errno syscall.Errno
+	var n uintptr
+	err = rc.Read(func(fd uintptr) bool {
+		// first 4 bytes is protocol family, in network byte order
+		head := [4]byte{}
+		iovecs := []syscall.Iovec{
+			{&head[0], 4},
+			{&to[0], uint64(len(to))},
+		}
+
+		n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
+		if errno.Temporary() {
+			// We got an EAGAIN, EINTR, or EWOULDBLOCK, go again
+			return false
+		}
+		return true
+	})
+	if err != nil {
+		if err == syscall.EBADF || err.Error() == "use of closed file" {
+			// Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are
+			// https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121
+			return 0, os.ErrClosed
+		}
+		return 0, fmt.Errorf("failed to make read call for tun: %w", err)
 	}
 
-	// Unsafe path routes
-	return t.addRoutes(false)
+	if errno != 0 {
+		return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno)
+	}
+
+	// fix bytes read number to exclude header
+	bytesRead := int(n)
+	if bytesRead < 0 {
+		return bytesRead, nil
+	} else if bytesRead < 4 {
+		return 0, nil
+	} else {
+		return bytesRead - 4, nil
+	}
+}
+
+// Write is only valid for single threaded use
+func (t *tun) Write(from []byte) (int, error) {
+	if len(from) <= 1 {
+		return 0, syscall.EIO
+	}
+
+	ipVer := from[0] >> 4
+	var head [4]byte
+	// first 4 bytes is protocol family, in network byte order
+	if ipVer == 4 {
+		head[3] = syscall.AF_INET
+	} else if ipVer == 6 {
+		head[3] = syscall.AF_INET6
+	} else {
+		return 0, fmt.Errorf("unable to determine IP version from packet")
+	}
+
+	rc, err := t.f.SyscallConn()
+	if err != nil {
+		return 0, err
+	}
+
+	var errno syscall.Errno
+	var n uintptr
+	err = rc.Write(func(fd uintptr) bool {
+		iovecs := []syscall.Iovec{
+			{&head[0], 4},
+			{&from[0], uint64(len(from))},
+		}
+
+		n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2))
+		// According to NetBSD documentation for TUN, writes will only return errors in which
+		// this packet will never be delivered so just go on living life.
+		return true
+	})
+	if err != nil {
+		return 0, err
+	}
+
+	if errno != 0 {
+		return 0, errno
+	}
+
+	return int(n) - 4, err
+}
+
+func (t *tun) addIp(cidr netip.Prefix) error {
+	if cidr.Addr().Is4() {
+		var req ifreqAlias4
+		req.Name = t.deviceBytes()
+		req.Addr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   cidr.Addr().As4(),
+		}
+		req.DstAddr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   cidr.Addr().As4(),
+		}
+		req.MaskAddr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(cidr).As4(),
+		}
+
+		s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
+		}
+
+		return nil
+	}
+
+	if cidr.Addr().Is6() {
+		var req ifreqAlias6
+		req.Name = t.deviceBytes()
+		req.Addr = unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   cidr.Addr().As16(),
+		}
+		req.PrefixMask = unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(cidr).As16(),
+		}
+		req.Lifetime = addrLifetime{
+			Vltime: 0xffffffff,
+			Pltime: 0xffffffff,
+		}
+
+		s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
+		}
+		return nil
+	}
+
+	return fmt.Errorf("unknown address type %v", cidr)
 }
 
 func (t *tun) Activate() error {
+	mode := int32(unix.IFF_BROADCAST)
+	err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode)))
+	if err != nil {
+		return fmt.Errorf("failed to set tun device mode: %w", err)
+	}
+
+	v := 1
+	err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v)))
+	if err != nil {
+		return fmt.Errorf("failed to set tun device head: %w", err)
+	}
+
+	err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
+	if err != nil {
+		return fmt.Errorf("failed to set tun mtu: %w", err)
+	}
+
 	for i := range t.vpnNetworks {
-		err := t.addIp(t.vpnNetworks[i])
+		err = t.addIp(t.vpnNetworks[i])
 		if err != nil {
 			return err
 		}
 	}
-	return nil
+
+	return t.addRoutes(false)
+}
+
+func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
+	s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+	if err != nil {
+		return err
+	}
+	defer syscall.Close(s)
+
+	ir := ifreq{Name: t.deviceBytes(), data: int(value)}
+	err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
+	return err
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
@@ -191,27 +390,33 @@ func (t *tun) Name() string {
 	return t.Device
 }
 
+func (t *tun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
 }
 
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
+		err := addRoute(r.Cidr, t.vpnNetworks)
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {
 				return retErr
 			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
 		}
 	}
 
@@ -224,10 +429,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		//TODO: CERT-V2 is this right?
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
+		err := delRoute(r.Cidr, t.vpnNetworks)
+		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 			t.l.WithField("route", r).Info("Removed route")
@@ -242,3 +445,109 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	}
 	return
 }
+
+func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_ADD,
+		Flags:   unix.RTF_UP | unix.RTF_GATEWAY,
+		Seq:     1,
+	}
+
+	if prefix.Addr().Is4() {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
+		}
+	} else {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		if errors.Is(err, unix.EEXIST) {
+			// Try to do a change
+			route.Type = unix.RTM_CHANGE
+			data, err = route.Marshal()
+			if err != nil {
+				return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
+			}
+			_, err = unix.Write(sock, data[:])
+			return err
+		}
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}
+
+func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_DELETE,
+		Seq:     1,
+	}
+
+	if prefix.Addr().Is4() {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
+		}
+	} else {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
+		}
+	}
+
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
+}

+ 315 - 97
overlay/tun_openbsd.go

@@ -4,23 +4,50 @@
 package overlay
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"net/netip"
 	"os"
-	"os/exec"
 	"regexp"
-	"strconv"
 	"sync/atomic"
 	"syscall"
+	"unsafe"
 
 	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/routing"
 	"github.com/slackhq/nebula/util"
+	netroute "golang.org/x/net/route"
+	"golang.org/x/sys/unix"
 )
 
+const (
+	SIOCAIFADDR_IN6 = 0x8080691a
+)
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   [2]uint32
+}
+
+type ifreq struct {
+	Name [unix.IFNAMSIZ]byte
+	data int
+}
+
 type tun struct {
 	Device      string
 	vpnNetworks []netip.Prefix
@@ -28,48 +55,46 @@ type tun struct {
 	Routes      atomic.Pointer[[]Route]
 	routeTree   atomic.Pointer[bart.Table[routing.Gateways]]
 	l           *logrus.Logger
-
-	io.ReadWriteCloser
-
+	f           *os.File
+	fd          int
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	out []byte
 }
 
-func (t *tun) Close() error {
-	if t.ReadWriteCloser != nil {
-		return t.ReadWriteCloser.Close()
-	}
-
-	return nil
-}
+var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
-	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
+	return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
 }
 
-var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
-
 func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
+	// Try to open tun device
+	var err error
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
-		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
+		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
 	}
-
 	if !deviceNameRE.MatchString(deviceName) {
-		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
+		return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified")
 	}
 
-	file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0)
+	fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0)
 	if err != nil {
 		return nil, err
 	}
 
+	err = unix.SetNonblock(fd, true)
+	if err != nil {
+		l.WithError(err).Warn("Failed to set the tun device as nonblocking")
+	}
+
 	t := &tun{
-		ReadWriteCloser: file,
-		Device:          deviceName,
-		vpnNetworks:     vpnNetworks,
-		MTU:             c.GetInt("tun.mtu", DefaultMTU),
-		l:               l,
+		f:           os.NewFile(uintptr(fd), ""),
+		fd:          fd,
+		Device:      deviceName,
+		vpnNetworks: vpnNetworks,
+		MTU:         c.GetInt("tun.mtu", DefaultMTU),
+		l:           l,
 	}
 
 	err = t.reload(c, true)
@@ -87,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
 	return t, nil
 }
 
+func (t *tun) Close() error {
+	if t.f != nil {
+		if err := t.f.Close(); err != nil {
+			return fmt.Errorf("error closing tun file: %w", err)
+		}
+
+		// t.f.Close should have handled it for us but let's be extra sure
+		_ = unix.Close(t.fd)
+	}
+	return nil
+}
+
+func (t *tun) Read(to []byte) (int, error) {
+	buf := make([]byte, len(to)+4)
+
+	n, err := t.f.Read(buf)
+
+	copy(to, buf[4:])
+	return n - 4, err
+}
+
+// Write is only valid for single threaded use
+func (t *tun) Write(from []byte) (int, error) {
+	buf := t.out
+	if cap(buf) < len(from)+4 {
+		buf = make([]byte, len(from)+4)
+		t.out = buf
+	}
+	buf = buf[:len(from)+4]
+
+	if len(from) == 0 {
+		return 0, syscall.EIO
+	}
+
+	// Determine the IP Family for the NULL L2 Header
+	ipVer := from[0] >> 4
+	if ipVer == 4 {
+		buf[3] = syscall.AF_INET
+	} else if ipVer == 6 {
+		buf[3] = syscall.AF_INET6
+	} else {
+		return 0, fmt.Errorf("unable to determine IP version from packet")
+	}
+
+	copy(buf[4:], from)
+
+	n, err := t.f.Write(buf)
+	return n - 4, err
+}
+
+func (t *tun) addIp(cidr netip.Prefix) error {
+	if cidr.Addr().Is4() {
+		var req ifreqAlias4
+		req.Name = t.deviceBytes()
+		req.Addr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   cidr.Addr().As4(),
+		}
+		req.DstAddr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   cidr.Addr().As4(),
+		}
+		req.MaskAddr = unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(cidr).As4(),
+		}
+
+		s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err)
+		}
+
+		err = addRoute(cidr, t.vpnNetworks)
+		if err != nil {
+			return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err)
+		}
+
+		return nil
+	}
+
+	if cidr.Addr().Is6() {
+		var req ifreqAlias6
+		req.Name = t.deviceBytes()
+		req.Addr = unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   cidr.Addr().As16(),
+		}
+		req.PrefixMask = unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(cidr).As16(),
+		}
+		req.Lifetime[0] = 0xffffffff
+		req.Lifetime[1] = 0xffffffff
+
+		s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+		if err != nil {
+			return err
+		}
+		defer syscall.Close(s)
+
+		if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil {
+			return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err)
+		}
+
+		return nil
+	}
+
+	return fmt.Errorf("unknown address type %v", cidr)
+}
+
+func (t *tun) Activate() error {
+	err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU))
+	if err != nil {
+		return fmt.Errorf("failed to set tun mtu: %w", err)
+	}
+
+	for i := range t.vpnNetworks {
+		err = t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+
+	return t.addRoutes(false)
+}
+
+func (t *tun) doIoctlByName(ctl uintptr, value uint32) error {
+	s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
+	if err != nil {
+		return err
+	}
+	defer syscall.Close(s)
+
+	ir := ifreq{Name: t.deviceBytes(), data: int(value)}
+	err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir)))
+	return err
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
@@ -124,63 +297,46 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) addIp(cidr netip.Prefix) error {
-	var err error
-	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
-	}
-
-	cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU))
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'ifconfig': %s", err)
-	}
+func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
+	r, _ := t.routeTree.Load().Lookup(ip)
+	return r
+}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
-	t.l.Debug("command: ", cmd.String())
-	if err = cmd.Run(); err != nil {
-		return fmt.Errorf("failed to run 'route add': %s", err)
-	}
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
 
-	// Unsafe path routes
-	return t.addRoutes(false)
+func (t *tun) Name() string {
+	return t.Device
 }
 
-func (t *tun) Activate() error {
-	for i := range t.vpnNetworks {
-		err := t.addIp(t.vpnNetworks[i])
-		if err != nil {
-			return err
-		}
-	}
-	return nil
+func (t *tun) SupportsMultiqueue() bool {
+	return false
 }
 
-func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
-	r, _ := t.routeTree.Load().Lookup(ip)
-	return r
+func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
 }
 
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 		if len(r.Via) == 0 || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
-		//TODO: CERT-V2 is this right?
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
-			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err)
+
+		err := addRoute(r.Cidr, t.vpnNetworks)
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err)
 			if logErrors {
 				retErr.Log(t.l)
 			} else {
 				return retErr
 			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
 		}
 	}
 
@@ -192,10 +348,9 @@ func (t *tun) removeRoutes(routes []Route) error {
 		if !r.Install {
 			continue
 		}
-		//TODO: CERT-V2 is this right?
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
-		t.l.Debug("command: ", cmd.String())
-		if err := cmd.Run(); err != nil {
+
+		err := delRoute(r.Cidr, t.vpnNetworks)
+		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 			t.l.WithField("route", r).Info("Removed route")
@@ -204,52 +359,115 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *tun) Networks() []netip.Prefix {
-	return t.vpnNetworks
+func (t *tun) deviceBytes() (o [16]byte) {
+	for i, c := range t.Device {
+		o[i] = byte(c)
+	}
+	return
 }
 
-func (t *tun) Name() string {
-	return t.Device
-}
+func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
 
-func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-	return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
-}
+	route := &netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_ADD,
+		Flags:   unix.RTF_UP | unix.RTF_GATEWAY,
+		Seq:     1,
+	}
 
-func (t *tun) Read(to []byte) (int, error) {
-	buf := make([]byte, len(to)+4)
+	if prefix.Addr().Is4() {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
+		}
+	} else {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
+		}
+	}
 
-	n, err := t.ReadWriteCloser.Read(buf)
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
 
-	copy(to, buf[4:])
-	return n - 4, err
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		if errors.Is(err, unix.EEXIST) {
+			// Try to do a change
+			route.Type = unix.RTM_CHANGE
+			data, err = route.Marshal()
+			if err != nil {
+				return fmt.Errorf("failed to create route.RouteMessage for change: %w", err)
+			}
+			_, err = unix.Write(sock, data[:])
+			return err
+		}
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
+
+	return nil
 }
 
-// Write is only valid for single threaded use
-func (t *tun) Write(from []byte) (int, error) {
-	buf := t.out
-	if cap(buf) < len(from)+4 {
-		buf = make([]byte, len(from)+4)
-		t.out = buf
+func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
 	}
-	buf = buf[:len(from)+4]
+	defer unix.Close(sock)
 
-	if len(from) == 0 {
-		return 0, syscall.EIO
+	route := netroute.RouteMessage{
+		Version: unix.RTM_VERSION,
+		Type:    unix.RTM_DELETE,
+		Seq:     1,
 	}
 
-	// Determine the IP Family for the NULL L2 Header
-	ipVer := from[0] >> 4
-	if ipVer == 4 {
-		buf[3] = syscall.AF_INET
-	} else if ipVer == 6 {
-		buf[3] = syscall.AF_INET6
+	if prefix.Addr().Is4() {
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()},
+		}
 	} else {
-		return 0, fmt.Errorf("unable to determine IP version from packet")
+		gw, err := selectGateway(prefix, gateways)
+		if err != nil {
+			return err
+		}
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()},
+		}
 	}
 
-	copy(buf[4:], from)
+	data, err := route.Marshal()
+	if err != nil {
+		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
+	}
+	_, err = unix.Write(sock, data[:])
+	if err != nil {
+		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
+	}
 
-	n, err := t.ReadWriteCloser.Write(buf)
-	return n - 4, err
+	return nil
 }

+ 4 - 0
overlay/tun_tester.go

@@ -132,6 +132,10 @@ func (t *TestTun) Read(b []byte) (int, error) {
 	return len(p), nil
 }
 
+func (t *TestTun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented")
 }

+ 4 - 0
overlay/tun_windows.go

@@ -234,6 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
 	return t.tun.Write(b, 0)
 }
 
+func (t *winTun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
 }

+ 4 - 0
overlay/user.go

@@ -46,6 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
 	return routing.Gateways{routing.NewGateway(ip, 1)}
 }
 
+func (d *UserDevice) SupportsMultiqueue() bool {
+	return true
+}
+
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return d, nil
 }

+ 1 - 0
pkclient/pkclient_cgo.go

@@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) {
 		pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true),
 		pkcs11.NewAttribute(pkcs11.CKA_WRAP, true),
 		pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true),
+		pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize),
 	}
 
 	// Set up the parameters which include the peer's public key

+ 54 - 42
pki.go

@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 		currentState := p.cs.Load()
 		if newState.v1Cert != nil {
 			if currentState.v1Cert == nil {
-				return util.NewContextualError("v1 certificate was added, restart required", nil, err)
+				//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
+			} else {
+				// did IP in cert change? if so, don't set
+				if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
+					return util.NewContextualError(
+						"Networks in new cert was different from old",
+						m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1},
+						nil,
+					)
+				}
+
+				if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
+					return util.NewContextualError(
+						"Curve in new v1 cert was different from old",
+						m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1},
+						nil,
+					)
+				}
 			}
-
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
-					nil,
-				)
-			}
-
-			if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
-				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
-					nil,
-				)
-			}
-
-		} else if currentState.v1Cert != nil {
-			//TODO: CERT-V2 we should be able to tear this down
-			return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
 		}
 
 		if newState.v2Cert != nil {
 			if currentState.v2Cert == nil {
-				return util.NewContextualError("v2 certificate was added, restart required", nil, err)
+				//adding certs is fine, actually
+			} else {
+				// did IP in cert change? if so, don't set
+				if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
+					return util.NewContextualError(
+						"Networks in new cert was different from old",
+						m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2},
+						nil,
+					)
+				}
+
+				if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+					return util.NewContextualError(
+						"Curve in new cert was different from old",
+						m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2},
+						nil,
+					)
+				}
 			}
 
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
-					nil,
-				)
+		} else if currentState.v2Cert != nil {
+			//newState.v1Cert is non-nil bc empty certstates aren't permitted
+			if newState.v1Cert == nil {
+				return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
 			}
-
-			if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+			//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
+			if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
 				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+					"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
+					m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()},
 					nil,
 				)
 			}
-
-		} else if currentState.v2Cert != nil {
-			return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
 		}
 
 		// Cipher cant be hot swapped so just leave it at what it was before
@@ -173,7 +180,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 
 	p.cs.Store(newState)
 
-	//TODO: CERT-V2 newState needs a stringer that does json
 	if initial {
 		p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
 	} else {
@@ -359,7 +365,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
 			return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
 		}
 
-		//TODO: CERT-V2 make sure v2 has v1s address
+		if v1.Networks()[0] != v2.Networks()[0] {
+			return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil)
+		}
 
 		cs.initiatingVersion = dv
 	}
@@ -515,9 +523,13 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
 		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
 	}
 
-	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
-		l.WithField("fingerprint", fp).Info("Blocklisting cert")
-		caPool.BlocklistFingerprint(fp)
+	bl := c.GetStringSlice("pki.blocklist", []string{})
+	if len(bl) > 0 {
+		for _, fp := range bl {
+			caPool.BlocklistFingerprint(fp)
+		}
+
+		l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates")
 	}
 
 	return caPool, nil

+ 21 - 10
remote_list.go

@@ -190,7 +190,7 @@ type RemoteList struct {
 	// The full list of vpn addresses assigned to this host
 	vpnAddrs []netip.Addr
 
-	// A deduplicated set of addresses. Any accessor should lock beforehand.
+	// A deduplicated set of underlay addresses. Any accessor should lock beforehand.
 	addrs []netip.AddrPort
 
 	// A set of relay addresses. VpnIp addresses that the remote identified as relays.
@@ -201,8 +201,10 @@ type RemoteList struct {
 	// For learned addresses, this is the vpnIp that sent the packet
 	cache map[netip.Addr]*cache
 
-	hr        *hostnamesResults
-	shouldAdd func(netip.Addr) bool
+	hr *hostnamesResults
+
+	// shouldAdd is a nillable function that decides if x should be added to addrs.
+	shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
@@ -213,7 +215,7 @@ type RemoteList struct {
 }
 
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
+func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList {
 	r := &RemoteList{
 		vpnAddrs:  make([]netip.Addr, len(vpnAddrs)),
 		addrs:     make([]netip.AddrPort, 0),
@@ -336,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
 }
 
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
-	if !bad.IsValid() {
-		// relays can have nil udp Addrs
+func (r *RemoteList) BlockRemote(bad ViaSender) {
+	if bad.IsRelayed {
 		return
 	}
+
 	r.Lock()
 	defer r.Unlock()
 
 	// Check if we already blocked this addr
-	if r.unlockedIsBad(bad) {
+	if r.unlockedIsBad(bad.UdpAddr) {
 		return
 	}
 
 	// We copy here because we are taking something else's memory and we can't trust everything
-	r.badRemotes = append(r.badRemotes, bad)
+	r.badRemotes = append(r.badRemotes, bad.UdpAddr)
 
 	// Mark the next interaction must recollect/dedupe
 	r.shouldRebuild = true
@@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
 	return c
 }
 
+// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake
+func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) {
+	r.Lock()
+	r.badRemotes = nil
+	r.vpnAddrs = make([]netip.Addr, len(vpnAddrs))
+	copy(r.vpnAddrs, vpnAddrs)
+	r.Unlock()
+}
+
 // ResetBlockedRemotes locks and clears the blocked remotes list
 func (r *RemoteList) ResetBlockedRemotes() {
 	r.Lock()
@@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() {
 
 	dnsAddrs := r.hr.GetAddrs()
 	for _, addr := range dnsAddrs {
-		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
+		if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) {
 			if !r.unlockedIsBad(addr) {
 				addrs = append(addrs, addr)
 			}

+ 1 - 1
service/service_test.go

@@ -16,8 +16,8 @@ import (
 	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/overlay"
+	"go.yaml.in/yaml/v3"
 	"golang.org/x/sync/errgroup"
-	"gopkg.in/yaml.v3"
 )
 
 type m = map[string]any

+ 4 - 0
test/tun.go

@@ -34,6 +34,10 @@ func (NoopTun) Write([]byte) (int, error) {
 	return 0, nil
 }
 
+func (NoopTun) SupportsMultiqueue() bool {
+	return false
+}
+
 func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, errors.New("unsupported")
 }

+ 4 - 0
udp/conn.go

@@ -19,6 +19,7 @@ type Conn interface {
 	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
+	SupportsMultipleReaders() bool
 	Close() error
 }
 
@@ -33,6 +34,9 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 func (NoopConn) ListenOut(_ EncReader) {
 	return
 }
+func (NoopConn) SupportsMultipleReaders() bool {
+	return false
+}
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 }

+ 7 - 3
udp/udp_darwin.go

@@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
 			return ErrInvalidIPv6RemoteForSocket
 		}
 
-		var rsa unix.RawSockaddrInet6
-		rsa.Family = unix.AF_INET6
-		rsa.Addr = ap.Addr().As16()
+		var rsa unix.RawSockaddrInet4
+		rsa.Family = unix.AF_INET
+		rsa.Addr = ap.Addr().As4()
 		binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
 		sa = unsafe.Pointer(&rsa)
 		addrLen = syscall.SizeofSockaddrInet4
@@ -184,6 +184,10 @@ func (u *StdConn) ListenOut(r EncReader) {
 	}
 }
 
+func (u *StdConn) SupportsMultipleReaders() bool {
+	return false
+}
+
 func (u *StdConn) Rebind() error {
 	var err error
 	if u.isV4 {

+ 4 - 0
udp/udp_generic.go

@@ -85,3 +85,7 @@ func (u *GenericConn) ListenOut(r EncReader) {
 		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
 	}
 }
+
+func (u *GenericConn) SupportsMultipleReaders() bool {
+	return false
+}

+ 4 - 0
udp/udp_linux.go

@@ -72,6 +72,10 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
 	return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
 }
 
+func (u *StdConn) SupportsMultipleReaders() bool {
+	return true
+}
+
 func (u *StdConn) Rebind() error {
 	return nil
 }

+ 4 - 0
udp/udp_rio_windows.go

@@ -315,6 +315,10 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 
 }
 
+func (u *RIOConn) SupportsMultipleReaders() bool {
+	return false
+}
+
 func (u *RIOConn) Rebind() error {
 	return nil
 }

+ 4 - 0
udp/udp_tester.go

@@ -127,6 +127,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
 	return u.Addr, nil
 }
 
+func (u *TesterConn) SupportsMultipleReaders() bool {
+	return false
+}
+
 func (u *TesterConn) Rebind() error {
 	return nil
 }