Browse Source

Merge remote-tracking branch 'origin/master' into multiport

Wade Simmons 5 months ago
parent
commit
f36db374ac
100 changed files with 9522 additions and 5597 deletions
  1. 1 1
      .github/workflows/gofmt.yml
  2. 3 3
      .github/workflows/release.yml
  3. 3 0
      .github/workflows/smoke-extra.yml
  4. 1 1
      .github/workflows/smoke.yml
  5. 18 18
      .github/workflows/smoke/smoke-vagrant.sh
  6. 22 4
      .github/workflows/test.yml
  7. 3 1
      .gitignore
  8. 14 5
      Makefile
  9. 1 1
      README.md
  10. 25 12
      allow_list.go
  11. 34 25
      calculated_remote.go
  12. 61 5
      calculated_remote_test.go
  13. 1 1
      cert/Makefile
  14. 15 4
      cert/README.md
  15. 52 0
      cert/asn1.go
  16. 0 140
      cert/ca.go
  17. 296 0
      cert/ca_pool.go
  18. 559 0
      cert/ca_pool_test.go
  19. 104 968
      cert/cert.go
  20. 0 1230
      cert/cert_test.go
  21. 489 0
      cert/cert_v1.go
  22. 111 111
      cert/cert_v1.pb.go
  23. 0 0
      cert/cert_v1.proto
  24. 218 0
      cert/cert_v1_test.go
  25. 37 0
      cert/cert_v2.asn1
  26. 730 0
      cert/cert_v2.go
  27. 267 0
      cert/cert_v2_test.go
  28. 159 2
      cert/crypto.go
  29. 87 0
      cert/crypto_test.go
  30. 41 6
      cert/errors.go
  31. 141 0
      cert/helper_test.go
  32. 161 0
      cert/pem.go
  33. 292 0
      cert/pem_test.go
  34. 167 0
      cert/sign.go
  35. 90 0
      cert/sign_test.go
  36. 138 0
      cert_test/cert.go
  37. 137 73
      cmd/nebula-cert/ca.go
  38. 36 30
      cmd/nebula-cert/ca_test.go
  39. 49 20
      cmd/nebula-cert/keygen.go
  40. 6 5
      cmd/nebula-cert/keygen_test.go
  41. 10 3
      cmd/nebula-cert/main_test.go
  42. 15 0
      cmd/nebula-cert/p11_cgo.go
  43. 16 0
      cmd/nebula-cert/p11_stub.go
  44. 14 9
      cmd/nebula-cert/print.go
  45. 144 21
      cmd/nebula-cert/print_test.go
  46. 238 110
      cmd/nebula-cert/sign.go
  47. 74 75
      cmd/nebula-cert/sign_test.go
  48. 26 15
      cmd/nebula-cert/verify.go
  49. 13 30
      cmd/nebula-cert/verify_test.go
  50. 0 3
      config/config_test.go
  51. 61 36
      connection_manager.go
  52. 162 61
      connection_manager_test.go
  53. 32 23
      connection_state.go
  54. 37 36
      control.go
  55. 22 36
      control_test.go
  56. 43 22
      control_tester.go
  57. 79 39
      dns_server.go
  58. 20 5
      dns_server_test.go
  59. 389 168
      e2e/handshakes_test.go
  60. 0 125
      e2e/helpers.go
  61. 75 38
      e2e/helpers_test.go
  62. 5 4
      e2e/router/hostmap.go
  63. 44 22
      e2e/router/router.go
  64. 12 5
      examples/config.yml
  65. 96 94
      firewall.go
  66. 11 10
      firewall/packet.go
  67. 136 195
      firewall_test.go
  68. 21 19
      go.mod
  69. 42 37
      go.sum
  70. 223 104
      handshake_ix.go
  71. 181 124
      handshake_manager.go
  72. 18 11
      handshake_manager_test.go
  73. 171 104
      hostmap.go
  74. 23 33
      hostmap_test.go
  75. 2 2
      hostmap_tester.go
  76. 31 27
      inside.go
  77. 84 77
      interface.go
  78. 0 2
      iputil/packet.go
  79. 364 240
      lighthouse.go
  80. 173 146
      lighthouse_test.go
  81. 5 34
      main.go
  82. 0 2
      message_metrics.go
  83. 560 176
      nebula.pb.go
  84. 23 9
      nebula.proto
  85. 50 0
      noiseutil/pkcs11.go
  86. 156 137
      outside.go
  87. 525 17
      outside_test.go
  88. 1 1
      overlay/device.go
  89. 22 12
      overlay/route.go
  90. 39 33
      overlay/route_test.go
  91. 9 9
      overlay/tun.go
  92. 11 11
      overlay/tun_android.go
  93. 208 215
      overlay/tun_darwin.go
  94. 8 8
      overlay/tun_disabled.go
  95. 25 15
      overlay/tun_freebsd.go
  96. 10 10
      overlay/tun_ios.go
  97. 120 78
      overlay/tun_linux.go
  98. 28 17
      overlay/tun_netbsd.go
  99. 29 19
      overlay/tun_openbsd.go
  100. 17 17
      overlay/tun_tester.go

+ 1 - 1
.github/workflows/gofmt.yml

@@ -18,7 +18,7 @@ jobs:
 
 
     - uses: actions/setup-go@v5
     - uses: actions/setup-go@v5
       with:
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
         check-latest: true
 
 
     - name: Install goimports
     - name: Install goimports

+ 3 - 3
.github/workflows/release.yml

@@ -14,7 +14,7 @@ jobs:
 
 
       - uses: actions/setup-go@v5
       - uses: actions/setup-go@v5
         with:
         with:
-          go-version: '1.22'
+          go-version: '1.23'
           check-latest: true
           check-latest: true
 
 
       - name: Build
       - name: Build
@@ -37,7 +37,7 @@ jobs:
 
 
       - uses: actions/setup-go@v5
       - uses: actions/setup-go@v5
         with:
         with:
-          go-version: '1.22'
+          go-version: '1.23'
           check-latest: true
           check-latest: true
 
 
       - name: Build
       - name: Build
@@ -70,7 +70,7 @@ jobs:
 
 
       - uses: actions/setup-go@v5
       - uses: actions/setup-go@v5
         with:
         with:
-          go-version: '1.22'
+          go-version: '1.23'
           check-latest: true
           check-latest: true
 
 
       - name: Import certificates
       - name: Import certificates

+ 3 - 0
.github/workflows/smoke-extra.yml

@@ -27,6 +27,9 @@ jobs:
         go-version-file: 'go.mod'
         go-version-file: 'go.mod'
         check-latest: true
         check-latest: true
 
 
+    - name: add hashicorp source
+      run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
+
     - name: install vagrant
     - name: install vagrant
       run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
       run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
 
 

+ 1 - 1
.github/workflows/smoke.yml

@@ -22,7 +22,7 @@ jobs:
 
 
     - uses: actions/setup-go@v5
     - uses: actions/setup-go@v5
       with:
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
         check-latest: true
 
 
     - name: build
     - name: build

+ 18 - 18
.github/workflows/smoke/smoke-vagrant.sh

@@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
 docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
 docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
 
 
 vagrant up
 vagrant up
-vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
+vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
 
 
 docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
 sleep 1
 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
 sleep 1
-vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
+vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 15
 sleep 15
 
 
 # grab tcpdump pcaps for debugging
 # grab tcpdump pcaps for debugging
@@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
 # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
 # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
 # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
 # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
 
 
-docker exec host2 ncat -nklv 0.0.0.0 2000 &
-vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
+#docker exec host2 ncat -nklv 0.0.0.0 2000 &
+#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
 #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
 #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
 #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
 #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
 
 
@@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
 # Should fail because not allowed by host3 inbound firewall
 # Should fail because not allowed by host3 inbound firewall
 ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 
 
-set +x
-echo
-echo " *** Testing ncat from host2"
-echo
-set -x
+#set +x
+#echo
+#echo " *** Testing ncat from host2"
+#echo
+#set -x
 # Should fail because not allowed by host3 inbound firewall
 # Should fail because not allowed by host3 inbound firewall
 #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
 #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
@@ -82,18 +82,18 @@ echo
 echo " *** Testing ping from host3"
 echo " *** Testing ping from host3"
 echo
 echo
 set -x
 set -x
-vagrant ssh -c "ping -c1 192.168.100.1"
-vagrant ssh -c "ping -c1 192.168.100.2"
-
-set +x
-echo
-echo " *** Testing ncat from host3"
-echo
-set -x
+vagrant ssh -c "ping -c1 192.168.100.1" -- -T
+vagrant ssh -c "ping -c1 192.168.100.2" -- -T
+
+#set +x
+#echo
+#echo " *** Testing ncat from host3"
+#echo
+#set -x
 #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
 #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
 #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
 #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
 
 
-vagrant ssh -c "sudo xargs kill </nebula/pid"
+vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
 docker exec host2 sh -c 'kill 1'
 docker exec host2 sh -c 'kill 1'
 docker exec lighthouse1 sh -c 'kill 1'
 docker exec lighthouse1 sh -c 'kill 1'
 sleep 1
 sleep 1

+ 22 - 4
.github/workflows/test.yml

@@ -22,7 +22,7 @@ jobs:
 
 
     - uses: actions/setup-go@v5
     - uses: actions/setup-go@v5
       with:
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
         check-latest: true
 
 
     - name: Build
     - name: Build
@@ -55,7 +55,7 @@ jobs:
 
 
     - uses: actions/setup-go@v5
     - uses: actions/setup-go@v5
       with:
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
         check-latest: true
 
 
     - name: Build
     - name: Build
@@ -65,7 +65,25 @@ jobs:
       run: make test-boringcrypto
       run: make test-boringcrypto
 
 
     - name: End 2 end
     - name: End 2 end
-      run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+      run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
+
+  test-linux-pkcs11:
+    name: Build and test on linux with pkcs11
+    runs-on: ubuntu-latest
+    steps:
+
+    - uses: actions/checkout@v4
+
+    - uses: actions/setup-go@v5
+      with:
+        go-version: '1.22'
+        check-latest: true
+
+    - name: Build
+      run: make bin-pkcs11
+
+    - name: Test
+      run: make test-pkcs11
 
 
   test:
   test:
     name: Build and test on ${{ matrix.os }}
     name: Build and test on ${{ matrix.os }}
@@ -79,7 +97,7 @@ jobs:
 
 
     - uses: actions/setup-go@v5
     - uses: actions/setup-go@v5
       with:
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
         check-latest: true
 
 
     - name: Build nebula
     - name: Build nebula

+ 3 - 1
.gitignore

@@ -5,7 +5,8 @@
 /nebula-darwin
 /nebula-darwin
 /nebula.exe
 /nebula.exe
 /nebula-cert.exe
 /nebula-cert.exe
-/coverage.out
+**/coverage.out
+**/cover.out
 /cpu.pprof
 /cpu.pprof
 /build
 /build
 /*.tar.gz
 /*.tar.gz
@@ -13,5 +14,6 @@
 **.crt
 **.crt
 **.key
 **.key
 **.pem
 **.pem
+**.pub
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt

+ 14 - 5
Makefile

@@ -40,7 +40,7 @@ ALL_LINUX = linux-amd64 \
 	linux-mips64le \
 	linux-mips64le \
 	linux-mips-softfloat \
 	linux-mips-softfloat \
 	linux-riscv64 \
 	linux-riscv64 \
-        linux-loong64
+	linux-loong64
 
 
 ALL_FREEBSD = freebsd-amd64 \
 ALL_FREEBSD = freebsd-amd64 \
 	freebsd-arm64
 	freebsd-arm64
@@ -63,7 +63,7 @@ ALL = $(ALL_LINUX) \
 e2e:
 e2e:
 	$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
 	$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
 
 
-e2ev: TEST_FLAGS = -v
+e2ev: TEST_FLAGS += -v
 e2ev: e2e
 e2ev: e2e
 
 
 e2evv: TEST_ENV += TEST_LOGS=1
 e2evv: TEST_ENV += TEST_LOGS=1
@@ -96,7 +96,7 @@ release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)
 
 
 release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
 release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
 
 
-BUILD_ARGS = -trimpath
+BUILD_ARGS += -trimpath
 
 
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
 	mv $? .
 	mv $? .
@@ -116,6 +116,10 @@ bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
 bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
 bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
 	mv $? .
 	mv $? .
 
 
+bin-pkcs11: BUILD_ARGS += -tags pkcs11
+bin-pkcs11: CGO_ENABLED = 1
+bin-pkcs11: bin
+
 bin:
 bin:
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
@@ -133,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
 # boringcrypto
 # boringcrypto
 build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
 build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
 build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
 build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
+build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
 
 
 build/%/nebula: .FORCE
 build/%/nebula: .FORCE
 	GOOS=$(firstword $(subst -, , $*)) \
 	GOOS=$(firstword $(subst -, , $*)) \
@@ -166,7 +172,10 @@ test:
 	go test -v ./...
 	go test -v ./...
 
 
 test-boringcrypto:
 test-boringcrypto:
-	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
+	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
+
+test-pkcs11:
+	CGO_ENABLED=1 go test -v -tags pkcs11 ./...
 
 
 test-cov-html:
 test-cov-html:
 	go test -coverprofile=coverage.out
 	go test -coverprofile=coverage.out
@@ -189,7 +198,7 @@ bench-cpu-long:
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 
 
-proto: nebula.pb.go cert/cert.pb.go
+proto: nebula.pb.go cert/cert_v1.pb.go
 
 
 nebula.pb.go: nebula.proto .FORCE
 nebula.pb.go: nebula.proto .FORCE
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster

+ 1 - 1
README.md

@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
 
 
 You can read more about Nebula [here](https://medium.com/p/884110a5579).
 You can read more about Nebula [here](https://medium.com/p/884110a5579).
 
 
-You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
+You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
 
 
 ## Supported Platforms
 ## Supported Platforms
 
 

+ 25 - 12
allow_list.go

@@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 
 		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
 		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
 
 
-		// TODO: should we error on duplicate CIDRs in the config?
 		tree.Insert(ipNet, value)
 		tree.Insert(ipNet, value)
 
 
 		maskBits := ipNet.Bits()
 		maskBits := ipNet.Bits()
@@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
 	return remoteAllowRanges, nil
 	return remoteAllowRanges, nil
 }
 }
 
 
-func (al *AllowList) Allow(ip netip.Addr) bool {
+func (al *AllowList) Allow(addr netip.Addr) bool {
 	if al == nil {
 	if al == nil {
 		return true
 		return true
 	}
 	}
 
 
-	result, _ := al.cidrTree.Lookup(ip)
+	result, _ := al.cidrTree.Lookup(addr)
 	return result
 	return result
 }
 }
 
 
-func (al *LocalAllowList) Allow(ip netip.Addr) bool {
+func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
 	if al == nil {
 	if al == nil {
 		return true
 		return true
 	}
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 }
 
 
 func (al *LocalAllowList) AllowName(name string) bool {
 func (al *LocalAllowList) AllowName(name string) bool {
@@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 	return !al.nameRules[0].Allow
 }
 }
 
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
+func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
 	if al == nil {
 	if al == nil {
 		return true
 		return true
 	}
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(vpnAddr)
 }
 }
 
 
-func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
-	if !al.getInsideAllowList(vpnIp).Allow(ip) {
+func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
+	if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
 		return false
 		return false
 	}
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 }
 
 
-func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
+func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
+	if !al.AllowList.Allow(udpAddr) {
+		return false
+	}
+
+	for _, vpnAddr := range vpnAddrs {
+		if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
+			return false
+		}
+	}
+
+	return true
+}
+
+func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
 	if al.insideAllowLists != nil {
-		inside, ok := al.insideAllowLists.Lookup(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnAddr)
 		if ok {
 		if ok {
 			return inside
 			return inside
 		}
 		}

+ 34 - 25
calculated_remote.go

@@ -21,7 +21,11 @@ type calculatedRemote struct {
 	port  uint32
 	port  uint32
 }
 }
 
 
-func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
+		return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
+	}
+
 	masked := maskCidr.Masked()
 	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
 		return nil, fmt.Errorf("invalid port: %d", port)
@@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 }
 
 
-func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
-	// Combine the masked bytes of the "mask" IP with the unmasked bytes
-	// of the overlay IP
-	if c.ipNet.Addr().Is4() {
-		return c.apply4(ip)
-	}
-	return c.apply6(ip)
-}
-
-func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK this can be less crappy
+func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
+	// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
 	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
 	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
 	mask := binary.BigEndian.Uint32(maskb[:])
 	mask := binary.BigEndian.Uint32(maskb[:])
 
 
 	b := c.mask.Addr().As4()
 	b := c.mask.Addr().As4()
-	maskIp := binary.BigEndian.Uint32(b[:])
+	maskAddr := binary.BigEndian.Uint32(b[:])
 
 
-	b = ip.As4()
-	intIp := binary.BigEndian.Uint32(b[:])
+	b = addr.As4()
+	intAddr := binary.BigEndian.Uint32(b[:])
 
 
-	return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+	return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
 }
 }
 
 
-func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK
-	panic("Can not calculate ipv6 remote addresses")
+func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
+	mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	maskAddr := c.mask.Addr().As16()
+	calcAddr := addr.As16()
+
+	ap := V6AddrPort{Port: c.port}
+
+	maskb := binary.BigEndian.Uint64(mask[:8])
+	maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
+	calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
+	ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	maskb = binary.BigEndian.Uint64(mask[8:])
+	maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
+	calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
+	ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	return &ap
 }
 }
 
 
 func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
 func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 		}
 
 
-		//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
-		entry, err := newCalculatedRemotesListFromConfig(rawValue)
+		entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
 		}
@@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 	return calculatedRemotes, nil
 	return calculatedRemotes, nil
 }
 }
 
 
-func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
+func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
 	rawList, ok := raw.([]any)
 	rawList, ok := raw.([]any)
 	if !ok {
 	if !ok {
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 
 
 	var l []*calculatedRemote
 	var l []*calculatedRemote
 	for _, e := range rawList {
 	for _, e := range rawList {
-		c, err := newCalculatedRemotesEntryFromConfig(e)
+		c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 		}
 		}
@@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 	return l, nil
 	return l, nil
 }
 }
 
 
-func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
+func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
 	rawMap, ok := raw.(map[any]any)
 	rawMap, ok := raw.(map[any]any)
 	if !ok {
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
 		return nil, fmt.Errorf("invalid type: %T", raw)
@@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 	}
 
 
-	return newCalculatedRemote(maskCidr, port)
+	return newCalculatedRemote(cidr, maskCidr, port)
 }
 }

+ 61 - 5
calculated_remote_test.go

@@ -9,10 +9,9 @@ import (
 )
 )
 
 
 func TestCalculatedRemoteApply(t *testing.T) {
 func TestCalculatedRemoteApply(t *testing.T) {
-	ipNet, err := netip.ParsePrefix("192.168.1.0/24")
-	require.NoError(t, err)
-
-	c, err := newCalculatedRemote(ipNet, 4242)
+	// Test v4 addresses
+	ipNet := netip.MustParsePrefix("192.168.1.0/24")
+	c, err := newCalculatedRemote(ipNet, ipNet, 4242)
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
 	input, err := netip.ParseAddr("10.0.10.182")
 	input, err := netip.ParseAddr("10.0.10.182")
@@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	expected, err := netip.ParseAddr("192.168.1.182")
 	expected, err := netip.ParseAddr("192.168.1.182")
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
+	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
+
+	// Test v6 addresses
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+}
+
+func Test_newCalculatedRemote(t *testing.T) {
+	c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
 }
 }

+ 1 - 1
cert/Makefile

@@ -1,7 +1,7 @@
 GO111MODULE = on
 GO111MODULE = on
 export GO111MODULE
 export GO111MODULE
 
 
-cert.pb.go: cert.proto .FORCE
+cert_v1.pb.go: cert_v1.proto .FORCE
 	go build google.golang.org/protobuf/cmd/protoc-gen-go
 	go build google.golang.org/protobuf/cmd/protoc-gen-go
 	PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
 	PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
 	rm protoc-gen-go
 	rm protoc-gen-go

+ 15 - 4
cert/README.md

@@ -2,14 +2,25 @@
 
 
 This is a library for interacting with `nebula` style certificates and authorities.
 This is a library for interacting with `nebula` style certificates and authorities.
 
 
-A `protobuf` definition of the certificate format is also included
+There are now 2 versions of `nebula` certificates:
 
 
-### Compiling the protobuf definition
+## v1
 
 
-Make sure you have `protoc` installed.
+This version is deprecated.
+
+A `protobuf` definition of the certificate format is included at `cert_v1.proto`
+
+To compile the definition you will need `protoc` installed.
 
 
 To compile for `go` with the same version of protobuf specified in go.mod:
 To compile for `go` with the same version of protobuf specified in go.mod:
 
 
 ```bash
 ```bash
-make
+make proto
 ```
 ```
+
+## v2
+
+This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
+future certificate changes better than v1.
+
+`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

+ 52 - 0
cert/asn1.go

@@ -0,0 +1,52 @@
+package cert
+
+import (
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+)
+
+// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
+// https://github.com/golang/go/issues/64811#issuecomment-1944446920
+func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0] > 0
+		return true
+	}
+
+	return false
+}
+
+// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
+// Similar issue as with readOptionalASN1Boolean
+func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0]
+		return true
+	}
+
+	return false
+}

+ 0 - 140
cert/ca.go

@@ -1,140 +0,0 @@
-package cert
-
-import (
-	"errors"
-	"fmt"
-	"strings"
-	"time"
-)
-
-type NebulaCAPool struct {
-	CAs           map[string]*NebulaCertificate
-	certBlocklist map[string]struct{}
-}
-
-// NewCAPool creates a CAPool
-func NewCAPool() *NebulaCAPool {
-	ca := NebulaCAPool{
-		CAs:           make(map[string]*NebulaCertificate),
-		certBlocklist: make(map[string]struct{}),
-	}
-
-	return &ca
-}
-
-// NewCAPoolFromBytes will create a new CA pool from the provided
-// input bytes, which must be a PEM-encoded set of nebula certificates.
-// If the pool contains any expired certificates, an ErrExpired will be
-// returned along with the pool. The caller must handle any such errors.
-func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
-	pool := NewCAPool()
-	var err error
-	var expired bool
-	for {
-		caPEMs, err = pool.AddCACertificate(caPEMs)
-		if errors.Is(err, ErrExpired) {
-			expired = true
-			err = nil
-		}
-		if err != nil {
-			return nil, err
-		}
-		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
-			break
-		}
-	}
-
-	if expired {
-		return pool, ErrExpired
-	}
-
-	return pool, nil
-}
-
-// AddCACertificate verifies a Nebula CA certificate and adds it to the pool
-// Only the first pem encoded object will be consumed, any remaining bytes are returned.
-// Parsed certificates will be verified and must be a CA
-func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
-	c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
-	if err != nil {
-		return pemBytes, err
-	}
-
-	if !c.Details.IsCA {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
-	}
-
-	if !c.CheckSignature(c.Details.PublicKey) {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
-	}
-
-	sum, err := c.Sha256Sum()
-	if err != nil {
-		return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
-	}
-
-	ncp.CAs[sum] = c
-	if c.Expired(time.Now()) {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
-	}
-
-	return pemBytes, nil
-}
-
-// BlocklistFingerprint adds a cert fingerprint to the blocklist
-func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
-	ncp.certBlocklist[f] = struct{}{}
-}
-
-// ResetCertBlocklist removes all previously blocklisted cert fingerprints
-func (ncp *NebulaCAPool) ResetCertBlocklist() {
-	ncp.certBlocklist = make(map[string]struct{})
-}
-
-// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated
-// automatically if you manually change any fields in the NebulaCertificate.
-func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
-	return ncp.isBlocklistedWithCache(c, false)
-}
-
-// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
-func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool {
-	h, err := c.sha256SumWithCache(useCache)
-	if err != nil {
-		return true
-	}
-
-	if _, ok := ncp.certBlocklist[h]; ok {
-		return true
-	}
-
-	return false
-}
-
-// GetCAForCert attempts to return the signing certificate for the provided certificate.
-// No signature validation is performed
-func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
-	if c.Details.Issuer == "" {
-		return nil, fmt.Errorf("no issuer in certificate")
-	}
-
-	signer, ok := ncp.CAs[c.Details.Issuer]
-	if ok {
-		return signer, nil
-	}
-
-	return nil, fmt.Errorf("could not find ca for the certificate")
-}
-
-// GetFingerprints returns an array of trusted CA fingerprints
-func (ncp *NebulaCAPool) GetFingerprints() []string {
-	fp := make([]string, len(ncp.CAs))
-
-	i := 0
-	for k := range ncp.CAs {
-		fp[i] = k
-		i++
-	}
-
-	return fp
-}

+ 296 - 0
cert/ca_pool.go

@@ -0,0 +1,296 @@
+package cert
+
+import (
+	"errors"
+	"fmt"
+	"net/netip"
+	"slices"
+	"strings"
+	"time"
+)
+
+type CAPool struct {
+	CAs           map[string]*CachedCertificate
+	certBlocklist map[string]struct{}
+}
+
+// NewCAPool creates an empty CAPool
+func NewCAPool() *CAPool {
+	ca := CAPool{
+		CAs:           make(map[string]*CachedCertificate),
+		certBlocklist: make(map[string]struct{}),
+	}
+
+	return &ca
+}
+
+// NewCAPoolFromPEM will create a new CA pool from the provided
+// input bytes, which must be a PEM-encoded set of nebula certificates.
+// If the pool contains any expired certificates, an ErrExpired will be
+// returned along with the pool. The caller must handle any such errors.
+func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
+	pool := NewCAPool()
+	var err error
+	var expired bool
+	for {
+		caPEMs, err = pool.AddCAFromPEM(caPEMs)
+		if errors.Is(err, ErrExpired) {
+			expired = true
+			err = nil
+		}
+		if err != nil {
+			return nil, err
+		}
+		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
+			break
+		}
+	}
+
+	if expired {
+		return pool, ErrExpired
+	}
+
+	return pool, nil
+}
+
+// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
+// Only the first pem encoded object will be consumed, any remaining bytes are returned.
+// Parsed certificates will be verified and must be a CA
+func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
+	c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
+	if err != nil {
+		return pemBytes, err
+	}
+
+	err = ncp.AddCA(c)
+	if err != nil {
+		return pemBytes, err
+	}
+
+	return pemBytes, nil
+}
+
+// AddCA verifies a Nebula CA certificate and adds it to the pool.
+func (ncp *CAPool) AddCA(c Certificate) error {
+	if !c.IsCA() {
+		return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
+	}
+
+	if !c.CheckSignature(c.PublicKey()) {
+		return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
+	}
+
+	sum, err := c.Fingerprint()
+	if err != nil {
+		return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
+	}
+
+	cc := &CachedCertificate{
+		Certificate:    c,
+		Fingerprint:    sum,
+		InvertedGroups: make(map[string]struct{}),
+	}
+
+	for _, g := range c.Groups() {
+		cc.InvertedGroups[g] = struct{}{}
+	}
+
+	ncp.CAs[sum] = cc
+
+	if c.Expired(time.Now()) {
+		return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
+	}
+
+	return nil
+}
+
+// BlocklistFingerprint adds a cert fingerprint to the blocklist
+func (ncp *CAPool) BlocklistFingerprint(f string) {
+	ncp.certBlocklist[f] = struct{}{}
+}
+
+// ResetCertBlocklist removes all previously blocklisted cert fingerprints
+func (ncp *CAPool) ResetCertBlocklist() {
+	ncp.certBlocklist = make(map[string]struct{})
+}
+
+// IsBlocklisted tests the provided fingerprint against the pools blocklist.
+// Returns true if the fingerprint is blocked.
+func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
+	if _, ok := ncp.certBlocklist[fingerprint]; ok {
+		return true
+	}
+
+	return false
+}
+
+// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
+// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
+// to increase performance.
+func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
+	if c == nil {
+		return nil, fmt.Errorf("no certificate")
+	}
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
+	}
+
+	signer, err := ncp.verify(c, now, fp, "")
+	if err != nil {
+		return nil, err
+	}
+
+	cc := CachedCertificate{
+		Certificate:       c,
+		InvertedGroups:    make(map[string]struct{}),
+		Fingerprint:       fp,
+		signerFingerprint: signer.Fingerprint,
+	}
+
+	for _, g := range c.Groups() {
+		cc.InvertedGroups[g] = struct{}{}
+	}
+
+	return &cc, nil
+}
+
+// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
+// is a cheaper operation to perform as a result.
+func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
+	_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
+	return err
+}
+
+func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
+	if ncp.IsBlocklisted(certFp) {
+		return nil, ErrBlockListed
+	}
+
+	signer, err := ncp.GetCAForCert(c)
+	if err != nil {
+		return nil, err
+	}
+
+	if signer.Certificate.Expired(now) {
+		return nil, ErrRootExpired
+	}
+
+	if c.Expired(now) {
+		return nil, ErrExpired
+	}
+
+	// If we are checking a cached certificate then we can bail early here
+	// Either the root is no longer trusted or everything is fine
+	if len(signerFp) > 0 {
+		if signerFp != signer.Fingerprint {
+			return nil, ErrFingerprintMismatch
+		}
+		return signer, nil
+	}
+	if !c.CheckSignature(signer.Certificate.PublicKey()) {
+		return nil, ErrSignatureMismatch
+	}
+
+	err = CheckCAConstraints(signer.Certificate, c)
+	if err != nil {
+		return nil, err
+	}
+
+	return signer, nil
+}
+
+// GetCAForCert attempts to return the signing certificate for the provided certificate.
+// No signature validation is performed
+func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
+	issuer := c.Issuer()
+	if issuer == "" {
+		return nil, fmt.Errorf("no issuer in certificate")
+	}
+
+	signer, ok := ncp.CAs[issuer]
+	if ok {
+		return signer, nil
+	}
+
+	return nil, ErrCaNotFound
+}
+
+// GetFingerprints returns an array of trusted CA fingerprints
+func (ncp *CAPool) GetFingerprints() []string {
+	fp := make([]string, len(ncp.CAs))
+
+	i := 0
+	for k := range ncp.CAs {
+		fp[i] = k
+		i++
+	}
+
+	return fp
+}
+
+// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
+func CheckCAConstraints(signer Certificate, sub Certificate) error {
+	return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
+}
+
+// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
+func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
+	// Make sure this cert isn't valid after the root
+	if notAfter.After(signer.NotAfter()) {
+		return fmt.Errorf("certificate expires after signing certificate")
+	}
+
+	// Make sure this cert wasn't valid before the root
+	if notBefore.Before(signer.NotBefore()) {
+		return fmt.Errorf("certificate is valid before the signing certificate")
+	}
+
+	// If the signer has a limited set of groups make sure the cert only contains a subset
+	signerGroups := signer.Groups()
+	if len(signerGroups) > 0 {
+		for _, g := range groups {
+			if !slices.Contains(signerGroups, g) {
+				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
+			}
+		}
+	}
+
+	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
+	signingNetworks := signer.Networks()
+	if len(signingNetworks) > 0 {
+		for _, certNetwork := range networks {
+			found := false
+			for _, signingNetwork := range signingNetworks {
+				if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
+			}
+		}
+	}
+
+	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
+	signingUnsafeNetworks := signer.UnsafeNetworks()
+	if len(signingUnsafeNetworks) > 0 {
+		for _, certUnsafeNetwork := range unsafeNetworks {
+			found := false
+			for _, caNetwork := range signingUnsafeNetworks {
+				if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
+			}
+		}
+	}
+
+	return nil
+}

+ 559 - 0
cert/ca_pool_test.go

@@ -0,0 +1,559 @@
+package cert
+
+import (
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNewCAPoolFromBytes(t *testing.T) {
+	noNewLines := `
+# Current provisional, Remove once everything moves over to the real root.
+-----BEGIN NEBULA CERTIFICATE-----
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
+-----END NEBULA CERTIFICATE-----
+# root-ca01
+-----BEGIN NEBULA CERTIFICATE-----
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
+-----END NEBULA CERTIFICATE-----
+`
+
+	withNewLines := `
+# Current provisional, Remove once everything moves over to the real root.
+
+-----BEGIN NEBULA CERTIFICATE-----
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
+-----END NEBULA CERTIFICATE-----
+
+# root-ca01
+
+
+-----BEGIN NEBULA CERTIFICATE-----
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
+-----END NEBULA CERTIFICATE-----
+
+`
+
+	expired := `
+# expired certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
+7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
+Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
+-----END NEBULA CERTIFICATE-----
+`
+
+	p256 := `
+# p256 certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
+k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
+75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
+-----END NEBULA CERTIFICATE-----
+`
+
+	rootCA := certificateV1{
+		details: detailsV1{
+			name: "nebula root ca",
+		},
+	}
+
+	rootCA01 := certificateV1{
+		details: detailsV1{
+			name: "nebula root ca 01",
+		},
+	}
+
+	rootCAP256 := certificateV1{
+		details: detailsV1{
+			name: "nebula P256 test",
+		},
+	}
+
+	p, err := NewCAPoolFromPEM([]byte(noNewLines))
+	assert.Nil(t, err)
+	assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+
+	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
+	assert.Nil(t, err)
+	assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+
+	// expired cert, no valid certs
+	ppp, err := NewCAPoolFromPEM([]byte(expired))
+	assert.Equal(t, ErrExpired, err)
+	assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
+
+	// expired cert, with valid certs
+	pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
+	assert.Equal(t, ErrExpired, err)
+	assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+	assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
+	assert.Equal(t, len(pppp.CAs), 3)
+
+	ppppp, err := NewCAPoolFromPEM([]byte(p256))
+	assert.Nil(t, err)
+	assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
+	assert.Equal(t, len(ppppp.CAs), 1)
+}
+
+func TestCertificateV1_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}

+ 104 - 968
cert/cert.go

@@ -1,1029 +1,165 @@
 package cert
 package cert
 
 
 import (
 import (
-	"bytes"
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/ed25519"
-	"crypto/elliptic"
-	"crypto/rand"
-	"crypto/sha256"
-	"encoding/binary"
-	"encoding/hex"
-	"encoding/json"
-	"encoding/pem"
-	"errors"
 	"fmt"
 	"fmt"
-	"math"
-	"math/big"
-	"net"
-	"sync/atomic"
+	"net/netip"
 	"time"
 	"time"
-
-	"golang.org/x/crypto/curve25519"
-	"google.golang.org/protobuf/proto"
 )
 )
 
 
-const publicKeyLen = 32
+type Version uint8
 
 
 const (
 const (
-	CertBanner                       = "NEBULA CERTIFICATE"
-	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
-	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
-	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
-	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
-	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
-
-	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
-	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
-	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
-	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+	VersionPre1 Version = 0
+	Version1    Version = 1
+	Version2    Version = 2
 )
 )
 
 
-type NebulaCertificate struct {
-	Details   NebulaCertificateDetails
-	Signature []byte
+type Certificate interface {
+	// Version defines the underlying certificate structure and wire protocol version
+	// Version1 certificates are ipv4 only and uses protobuf serialization
+	// Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization
+	Version() Version
 
 
-	// the cached hex string of the calculated sha256sum
-	// for VerifyWithCache
-	sha256sum atomic.Pointer[string]
+	// Name is the human-readable name that identifies this certificate.
+	Name() string
 
 
-	// the cached public key bytes if they were verified as the signer
-	// for VerifyWithCache
-	signatureVerified atomic.Pointer[[]byte]
-}
+	// Networks is a list of ip addresses and network sizes assigned to this certificate.
+	// If IsCA is true then certificates signed by this CA can only have ip addresses and
+	// networks that are contained by an entry in this list.
+	Networks() []netip.Prefix
 
 
-type NebulaCertificateDetails struct {
-	Name      string
-	Ips       []*net.IPNet
-	Subnets   []*net.IPNet
-	Groups    []string
-	NotBefore time.Time
-	NotAfter  time.Time
-	PublicKey []byte
-	IsCA      bool
-	Issuer    string
+	// UnsafeNetworks is a list of networks that this host can act as an unsafe router for.
+	// If IsCA is true then certificates signed by this CA can only have networks that are
+	// contained by an entry in this list.
+	UnsafeNetworks() []netip.Prefix
 
 
-	// Map of groups for faster lookup
-	InvertedGroups map[string]struct{}
+	// Groups is a list of identities that can be used to write more general firewall rule
+	// definitions.
+	// If IsCA is true then certificates signed by this CA can only use groups that are
+	// in this list.
+	Groups() []string
 
 
-	Curve Curve
-}
+	// IsCA signifies if this is a certificate authority (true) or a host certificate (false).
+	// It is invalid to use a CA certificate as a host certificate.
+	IsCA() bool
 
 
-type NebulaEncryptedData struct {
-	EncryptionMetadata NebulaEncryptionMetadata
-	Ciphertext         []byte
-}
+	// NotBefore is the time at which this certificate becomes valid.
+	// If IsCA is true then certificate signed by this CA can not have a time before this.
+	NotBefore() time.Time
 
 
-type NebulaEncryptionMetadata struct {
-	EncryptionAlgorithm string
-	Argon2Parameters    Argon2Parameters
-}
+	// NotAfter is the time at which this certificate becomes invalid.
+	// If IsCA is true then certificate signed by this CA can not have a time after this.
+	NotAfter() time.Time
 
 
-type m map[string]interface{}
+	// Issuer is the fingerprint of the CA that signed this certificate.
+	// If IsCA is true then this will be empty.
+	Issuer() string
 
 
-// Returned if we try to unmarshal an encrypted private key without a passphrase
-var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
+	// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
+	PublicKey() []byte
 
 
-// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert
-func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
-	if len(b) == 0 {
-		return nil, fmt.Errorf("nil byte array")
-	}
-	var rc RawNebulaCertificate
-	err := proto.Unmarshal(b, &rc)
-	if err != nil {
-		return nil, err
-	}
-
-	if rc.Details == nil {
-		return nil, fmt.Errorf("encoded Details was nil")
-	}
-
-	if len(rc.Details.Ips)%2 != 0 {
-		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
-	}
-
-	if len(rc.Details.Subnets)%2 != 0 {
-		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
-	}
+	// Curve identifies which curve was used for the PublicKey and Signature.
+	Curve() Curve
 
 
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           rc.Details.Name,
-			Groups:         make([]string, len(rc.Details.Groups)),
-			Ips:            make([]*net.IPNet, len(rc.Details.Ips)/2),
-			Subnets:        make([]*net.IPNet, len(rc.Details.Subnets)/2),
-			NotBefore:      time.Unix(rc.Details.NotBefore, 0),
-			NotAfter:       time.Unix(rc.Details.NotAfter, 0),
-			PublicKey:      make([]byte, len(rc.Details.PublicKey)),
-			IsCA:           rc.Details.IsCA,
-			InvertedGroups: make(map[string]struct{}),
-			Curve:          rc.Details.Curve,
-		},
-		Signature: make([]byte, len(rc.Signature)),
-	}
+	// Signature is the cryptographic seal for all the details of this certificate.
+	// CheckSignature can be used to verify that the details of this certificate are valid.
+	Signature() []byte
 
 
-	copy(nc.Signature, rc.Signature)
-	copy(nc.Details.Groups, rc.Details.Groups)
-	nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer)
+	// CheckSignature will check that the certificate Signature() matches the
+	// computed signature. A true result means this certificate has not been tampered with.
+	CheckSignature(signingPublicKey []byte) bool
 
 
-	if len(rc.Details.PublicKey) < publicKeyLen {
-		return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
-	}
-	copy(nc.Details.PublicKey, rc.Details.PublicKey)
+	// Fingerprint returns the hex encoded sha256 sum of the certificate.
+	// This acts as a unique fingerprint and can be used to blocklist certificates.
+	Fingerprint() (string, error)
 
 
-	for i, rawIp := range rc.Details.Ips {
-		if i%2 == 0 {
-			nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)}
-		} else {
-			nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp))
-		}
-	}
+	// Expired tests if the certificate is valid for the provided time.
+	Expired(t time.Time) bool
 
 
-	for i, rawIp := range rc.Details.Subnets {
-		if i%2 == 0 {
-			nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)}
-		} else {
-			nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp))
-		}
-	}
+	// VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key.
+	VerifyPrivateKey(curve Curve, privateKey []byte) error
 
 
-	for _, g := range rc.Details.Groups {
-		nc.Details.InvertedGroups[g] = struct{}{}
-	}
+	// Marshal will return the byte representation of this certificate
+	// This is primarily the format transmitted on the wire.
+	Marshal() ([]byte, error)
 
 
-	return &nc, nil
-}
+	// MarshalForHandshakes prepares the bytes needed to use directly in a handshake
+	MarshalForHandshakes() ([]byte, error)
 
 
-// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data
-// or an error on failure
-func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
-	p, r := pem.Decode(b)
-	if p == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if p.Type != CertBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
-	}
-	nc, err := UnmarshalNebulaCertificate(p.Bytes)
-	return nc, r, err
-}
+	// MarshalPEM will return a PEM encoded representation of this certificate
+	// This is primarily the format stored on disk
+	MarshalPEM() ([]byte, error)
 
 
-func MarshalPrivateKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
+	// MarshalJSON will return the json representation of this certificate
+	MarshalJSON() ([]byte, error)
 
 
-func MarshalSigningPrivateKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
+	// String will return a human-readable representation of this certificate
+	String() string
 
 
-// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key
-func MarshalX25519PrivateKey(b []byte) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
+	// Copy creates a copy of the certificate
+	Copy() Certificate
 }
 }
 
 
-// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key
-func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
+// CachedCertificate represents a verified certificate with some cached fields to improve
+// performance.
+type CachedCertificate struct {
+	Certificate       Certificate
+	InvertedGroups    map[string]struct{}
+	Fingerprint       string
+	signerFingerprint string
 }
 }
 
 
-func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var expectedLen int
-	var curve Curve
-	switch k.Type {
-	case X25519PrivateKeyBanner:
-		expectedLen = 32
-		curve = Curve_CURVE25519
-	case P256PrivateKeyBanner:
-		expectedLen = 32
-		curve = Curve_P256
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner")
-	}
-	if len(k.Bytes) != expectedLen {
-		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
-	}
-	return k.Bytes, r, curve, nil
-}
-
-func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var curve Curve
-	switch k.Type {
-	case EncryptedEd25519PrivateKeyBanner:
-		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
-	case EncryptedECDSAP256PrivateKeyBanner:
-		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
-	case Ed25519PrivateKeyBanner:
-		curve = Curve_CURVE25519
-		if len(k.Bytes) != ed25519.PrivateKeySize {
-			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
-		}
-	case ECDSAP256PrivateKeyBanner:
-		curve = Curve_P256
-		if len(k.Bytes) != 32 {
-			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
-		}
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
-	}
-	return k.Bytes, r, curve, nil
-}
-
-// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
-func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
-	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
-	if err != nil {
-		return nil, err
-	}
-
-	b, err = proto.Marshal(&RawNebulaEncryptedData{
-		EncryptionMetadata: &RawNebulaEncryptionMetadata{
-			EncryptionAlgorithm: "AES-256-GCM",
-			Argon2Parameters: &RawNebulaArgon2Parameters{
-				Version:     kdfParams.version,
-				Memory:      kdfParams.Memory,
-				Parallelism: uint32(kdfParams.Parallelism),
-				Iterations:  kdfParams.Iterations,
-				Salt:        kdfParams.salt,
-			},
-		},
-		Ciphertext: ciphertext,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
-	default:
-		return nil, fmt.Errorf("invalid curve: %v", curve)
-	}
-}
-
-// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
-// or an error on failure
-func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != X25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner")
-	}
-	if len(k.Bytes) != publicKeyLen {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b
-// or an error on failure
-func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-
-	if k.Type == EncryptedEd25519PrivateKeyBanner {
-		return nil, r, ErrPrivateKeyEncrypted
-	} else if k.Type != Ed25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner")
-	}
-
-	if len(k.Bytes) != ed25519.PrivateKeySize {
-		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
-// protobuf-generated struct.
-func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
-	if len(b) == 0 {
-		return nil, fmt.Errorf("nil byte array")
-	}
-	var rned RawNebulaEncryptedData
-	err := proto.Unmarshal(b, &rned)
-	if err != nil {
-		return nil, err
-	}
-
-	if rned.EncryptionMetadata == nil {
-		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
-	}
-
-	if rned.EncryptionMetadata.Argon2Parameters == nil {
-		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
-	}
-
-	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
-	if err != nil {
-		return nil, err
-	}
-
-	ned := NebulaEncryptedData{
-		EncryptionMetadata: NebulaEncryptionMetadata{
-			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
-			Argon2Parameters:    *params,
-		},
-		Ciphertext: rned.Ciphertext,
-	}
-
-	return &ned, nil
-}
-
-func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
-	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
-		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
-	}
-	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
-		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
-	}
-	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
-		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
-	}
-	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
-		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
-	}
-
-	return &Argon2Parameters{
-		version:     rune(params.Version),
-		Memory:      uint32(params.Memory),
-		Parallelism: uint8(params.Parallelism),
-		Iterations:  uint32(params.Iterations),
-		salt:        params.Salt,
-	}, nil
-
-}
-
-// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
-// the given passphrase, returning any other bytes b or an error on failure
-func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
-	var curve Curve
-
-	k, r := pem.Decode(b)
-	if k == nil {
-		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-
-	switch k.Type {
-	case EncryptedEd25519PrivateKeyBanner:
-		curve = Curve_CURVE25519
-	case EncryptedECDSAP256PrivateKeyBanner:
-		curve = Curve_P256
-	default:
-		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
-	}
-
-	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
-	if err != nil {
-		return curve, nil, r, err
-	}
-
-	var bytes []byte
-	switch ned.EncryptionMetadata.EncryptionAlgorithm {
-	case "AES-256-GCM":
-		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
-		if err != nil {
-			return curve, nil, r, err
-		}
-	default:
-		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
-	}
-
-	switch curve {
-	case Curve_CURVE25519:
-		if len(bytes) != ed25519.PrivateKeySize {
-			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
-		}
-	case Curve_P256:
-		if len(bytes) != 32 {
-			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
-		}
-	}
-
-	return curve, bytes, r, nil
-}
-
-func MarshalPublicKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
-
-// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
-func MarshalX25519PublicKey(b []byte) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
-}
-
-// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key
-func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key})
-}
-
-func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var expectedLen int
-	var curve Curve
-	switch k.Type {
-	case X25519PublicKeyBanner:
-		expectedLen = 32
-		curve = Curve_CURVE25519
-	case P256PublicKeyBanner:
-		// Uncompressed
-		expectedLen = 65
-		curve = Curve_P256
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner")
-	}
-	if len(k.Bytes) != expectedLen {
-		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
-	}
-	return k.Bytes, r, curve, nil
-}
-
-// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b
-// or an error on failure
-func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != X25519PublicKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner")
-	}
-	if len(k.Bytes) != publicKeyLen {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b
-// or an error on failure
-func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != Ed25519PublicKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner")
-	}
-	if len(k.Bytes) != ed25519.PublicKeySize {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key")
-	}
-
-	return k.Bytes, r, nil
+func (cc *CachedCertificate) String() string {
+	return cc.Certificate.String()
 }
 }
 
 
-// Sign signs a nebula cert with the provided private key
-func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error {
-	if curve != nc.Details.Curve {
-		return fmt.Errorf("curve in cert and private key supplied don't match")
+// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
+// Handshakes save space by placing the peers public key in a different part of the packet, we have to
+// reassemble the actual certificate structure with that in mind.
+func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
+	if publicKey == nil {
+		return nil, ErrNoPeerStaticKey
 	}
 	}
 
 
-	b, err := proto.Marshal(nc.getRawDetails())
-	if err != nil {
-		return err
+	if rawCertBytes == nil {
+		return nil, ErrNoPayload
 	}
 	}
 
 
-	var sig []byte
-
-	switch curve {
-	case Curve_CURVE25519:
-		signer := ed25519.PrivateKey(key)
-		sig = ed25519.Sign(signer, b)
-	case Curve_P256:
-		signer := &ecdsa.PrivateKey{
-			PublicKey: ecdsa.PublicKey{
-				Curve: elliptic.P256(),
-			},
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-			D: new(big.Int).SetBytes(key),
-		}
-		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-		signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
-
-		// We need to hash first for ECDSA
-		// - https://pkg.go.dev/crypto/ecdsa#SignASN1
-		hashed := sha256.Sum256(b)
-		sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
-		if err != nil {
-			return err
-		}
-	default:
-		return fmt.Errorf("invalid curve: %s", nc.Details.Curve)
-	}
-
-	nc.Signature = sig
-	return nil
-}
-
-// CheckSignature verifies the signature against the provided public key
-func (nc *NebulaCertificate) CheckSignature(key []byte) bool {
-	b, err := proto.Marshal(nc.getRawDetails())
+	c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
 	if err != nil {
 	if err != nil {
-		return false
-	}
-	switch nc.Details.Curve {
-	case Curve_CURVE25519:
-		return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature)
-	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
-		hashed := sha256.Sum256(b)
-		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature)
-	default:
-		return false
-	}
-}
-
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool {
-	if !useCache {
-		return nc.CheckSignature(key)
-	}
-
-	if v := nc.signatureVerified.Load(); v != nil {
-		return bytes.Equal(*v, key)
-	}
-
-	verified := nc.CheckSignature(key)
-	if verified {
-		keyCopy := make([]byte, len(key))
-		copy(keyCopy, key)
-		nc.signatureVerified.Store(&keyCopy)
+		return nil, fmt.Errorf("error unmarshaling cert: %w", err)
 	}
 	}
 
 
-	return verified
-}
-
-// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false
-func (nc *NebulaCertificate) Expired(t time.Time) bool {
-	return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
-}
-
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
-	return nc.verify(t, ncp, false)
-}
-
-// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-//
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) {
-	return nc.verify(t, ncp, true)
-}
-
-// ResetCache resets the cache used by VerifyWithCache.
-func (nc *NebulaCertificate) ResetCache() {
-	nc.sha256sum.Store(nil)
-	nc.signatureVerified.Store(nil)
-}
-
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) {
-	if ncp.isBlocklistedWithCache(nc, useCache) {
-		return false, ErrBlockListed
-	}
-
-	signer, err := ncp.GetCAForCert(nc)
+	cc, err := caPool.VerifyCertificate(time.Now(), c)
 	if err != nil {
 	if err != nil {
-		return false, err
-	}
-
-	if signer.Expired(t) {
-		return false, ErrRootExpired
-	}
-
-	if nc.Expired(t) {
-		return false, ErrExpired
+		return nil, fmt.Errorf("certificate validation failed: %w", err)
 	}
 	}
 
 
-	if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) {
-		return false, ErrSignatureMismatch
-	}
-
-	if err := nc.CheckRootConstrains(signer); err != nil {
-		return false, err
-	}
-
-	return true, nil
+	return cc, nil
 }
 }
 
 
-// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets)
-func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error {
-	// Make sure this cert wasn't valid before the root
-	if signer.Details.NotAfter.Before(nc.Details.NotAfter) {
-		return fmt.Errorf("certificate expires after signing certificate")
-	}
+func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
+	var c Certificate
+	var err error
 
 
-	// Make sure this cert isn't valid after the root
-	if signer.Details.NotBefore.After(nc.Details.NotBefore) {
-		return fmt.Errorf("certificate is valid before the signing certificate")
-	}
-
-	// If the signer has a limited set of groups make sure the cert only contains a subset
-	if len(signer.Details.InvertedGroups) > 0 {
-		for _, g := range nc.Details.Groups {
-			if _, ok := signer.Details.InvertedGroups[g]; !ok {
-				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
-			}
-		}
-	}
-
-	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
-	if len(signer.Details.Ips) > 0 {
-		for _, ip := range nc.Details.Ips {
-			if !netMatch(ip, signer.Details.Ips) {
-				return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String())
-			}
-		}
-	}
-
-	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
-	if len(signer.Details.Subnets) > 0 {
-		for _, subnet := range nc.Details.Subnets {
-			if !netMatch(subnet, signer.Details.Subnets) {
-				return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet)
-			}
-		}
-	}
-
-	return nil
-}
-
-// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
-func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error {
-	if curve != nc.Details.Curve {
-		return fmt.Errorf("curve in cert and private key supplied don't match")
-	}
-	if nc.Details.IsCA {
-		switch curve {
-		case Curve_CURVE25519:
-			// the call to PublicKey below will panic slice bounds out of range otherwise
-			if len(key) != ed25519.PrivateKeySize {
-				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-			}
-
-			if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
-				return fmt.Errorf("public key in cert and private key supplied don't match")
-			}
-		case Curve_P256:
-			privkey, err := ecdh.P256().NewPrivateKey(key)
-			if err != nil {
-				return fmt.Errorf("cannot parse private key as P256")
-			}
-			pub := privkey.PublicKey().Bytes()
-			if !bytes.Equal(pub, nc.Details.PublicKey) {
-				return fmt.Errorf("public key in cert and private key supplied don't match")
-			}
-		default:
-			return fmt.Errorf("invalid curve: %s", curve)
-		}
-		return nil
-	}
-
-	var pub []byte
-	switch curve {
-	case Curve_CURVE25519:
-		var err error
-		pub, err = curve25519.X25519(key, curve25519.Basepoint)
-		if err != nil {
-			return err
-		}
-	case Curve_P256:
-		privkey, err := ecdh.P256().NewPrivateKey(key)
-		if err != nil {
-			return err
-		}
-		pub = privkey.PublicKey().Bytes()
+	switch v {
+	// Implementations must ensure the result is a valid cert!
+	case VersionPre1, Version1:
+		c, err = unmarshalCertificateV1(b, publicKey)
+	case Version2:
+		c, err = unmarshalCertificateV2(b, publicKey, curve)
 	default:
 	default:
-		return fmt.Errorf("invalid curve: %s", curve)
-	}
-	if !bytes.Equal(pub, nc.Details.PublicKey) {
-		return fmt.Errorf("public key in cert and private key supplied don't match")
-	}
-
-	return nil
-}
-
-// String will return a pretty printed representation of a nebula cert
-func (nc *NebulaCertificate) String() string {
-	if nc == nil {
-		return "NebulaCertificate {}\n"
-	}
-
-	s := "NebulaCertificate {\n"
-	s += "\tDetails {\n"
-	s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name)
-
-	if len(nc.Details.Ips) > 0 {
-		s += "\t\tIps: [\n"
-		for _, ip := range nc.Details.Ips {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tIps: []\n"
-	}
-
-	if len(nc.Details.Subnets) > 0 {
-		s += "\t\tSubnets: [\n"
-		for _, ip := range nc.Details.Subnets {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tSubnets: []\n"
-	}
-
-	if len(nc.Details.Groups) > 0 {
-		s += "\t\tGroups: [\n"
-		for _, g := range nc.Details.Groups {
-			s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tGroups: []\n"
-	}
-
-	s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore)
-	s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter)
-	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA)
-	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer)
-	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey)
-	s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve)
-	s += "\t}\n"
-	fp, err := nc.Sha256Sum()
-	if err == nil {
-		s += fmt.Sprintf("\tFingerprint: %s\n", fp)
+		//TODO: CERT-V2 make a static var
+		return nil, fmt.Errorf("unknown certificate version %d", v)
 	}
 	}
-	s += fmt.Sprintf("\tSignature: %x\n", nc.Signature)
-	s += "}"
 
 
-	return s
-}
-
-// getRawDetails marshals the raw details into protobuf ready struct
-func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails {
-	rd := &RawNebulaCertificateDetails{
-		Name:      nc.Details.Name,
-		Groups:    nc.Details.Groups,
-		NotBefore: nc.Details.NotBefore.Unix(),
-		NotAfter:  nc.Details.NotAfter.Unix(),
-		PublicKey: make([]byte, len(nc.Details.PublicKey)),
-		IsCA:      nc.Details.IsCA,
-		Curve:     nc.Details.Curve,
-	}
-
-	for _, ipNet := range nc.Details.Ips {
-		rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask))
-	}
-
-	for _, ipNet := range nc.Details.Subnets {
-		rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask))
-	}
-
-	copy(rd.PublicKey, nc.Details.PublicKey[:])
-
-	// I know, this is terrible
-	rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer)
-
-	return rd
-}
-
-// Marshal will marshal a nebula cert into a protobuf byte array
-func (nc *NebulaCertificate) Marshal() ([]byte, error) {
-	rc := RawNebulaCertificate{
-		Details:   nc.getRawDetails(),
-		Signature: nc.Signature,
-	}
-
-	return proto.Marshal(&rc)
-}
-
-// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result
-func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) {
-	b, err := nc.Marshal()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil
-}
-
-// Sha256Sum calculates a sha-256 sum of the marshaled certificate
-func (nc *NebulaCertificate) Sha256Sum() (string, error) {
-	b, err := nc.Marshal()
-	if err != nil {
-		return "", err
-	}
-
-	sum := sha256.Sum256(b)
-	return hex.EncodeToString(sum[:]), nil
-}
-
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) {
-	if !useCache {
-		return nc.Sha256Sum()
-	}
-
-	if s := nc.sha256sum.Load(); s != nil {
-		return *s, nil
-	}
-	s, err := nc.Sha256Sum()
-	if err != nil {
-		return s, err
-	}
-
-	nc.sha256sum.Store(&s)
-	return s, nil
-}
-
-func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
-	toString := func(ips []*net.IPNet) []string {
-		s := []string{}
-		for _, ip := range ips {
-			s = append(s, ip.String())
-		}
-		return s
-	}
-
-	fp, _ := nc.Sha256Sum()
-	jc := m{
-		"details": m{
-			"name":      nc.Details.Name,
-			"ips":       toString(nc.Details.Ips),
-			"subnets":   toString(nc.Details.Subnets),
-			"groups":    nc.Details.Groups,
-			"notBefore": nc.Details.NotBefore,
-			"notAfter":  nc.Details.NotAfter,
-			"publicKey": fmt.Sprintf("%x", nc.Details.PublicKey),
-			"isCa":      nc.Details.IsCA,
-			"issuer":    nc.Details.Issuer,
-			"curve":     nc.Details.Curve.String(),
-		},
-		"fingerprint": fp,
-		"signature":   fmt.Sprintf("%x", nc.Signature),
-	}
-	return json.Marshal(jc)
-}
-
-//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
-//	r, err := nc.Marshal()
-//	if err != nil {
-//		//TODO
-//		return nil
-//	}
-//
-//	c, err := UnmarshalNebulaCertificate(r)
-//	return c
-//}
-
-func (nc *NebulaCertificate) Copy() *NebulaCertificate {
-	c := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           nc.Details.Name,
-			Groups:         make([]string, len(nc.Details.Groups)),
-			Ips:            make([]*net.IPNet, len(nc.Details.Ips)),
-			Subnets:        make([]*net.IPNet, len(nc.Details.Subnets)),
-			NotBefore:      nc.Details.NotBefore,
-			NotAfter:       nc.Details.NotAfter,
-			PublicKey:      make([]byte, len(nc.Details.PublicKey)),
-			IsCA:           nc.Details.IsCA,
-			Issuer:         nc.Details.Issuer,
-			InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
-		},
-		Signature: make([]byte, len(nc.Signature)),
-	}
-
-	copy(c.Signature, nc.Signature)
-	copy(c.Details.Groups, nc.Details.Groups)
-	copy(c.Details.PublicKey, nc.Details.PublicKey)
-
-	for i, p := range nc.Details.Ips {
-		c.Details.Ips[i] = &net.IPNet{
-			IP:   make(net.IP, len(p.IP)),
-			Mask: make(net.IPMask, len(p.Mask)),
-		}
-		copy(c.Details.Ips[i].IP, p.IP)
-		copy(c.Details.Ips[i].Mask, p.Mask)
-	}
-
-	for i, p := range nc.Details.Subnets {
-		c.Details.Subnets[i] = &net.IPNet{
-			IP:   make(net.IP, len(p.IP)),
-			Mask: make(net.IPMask, len(p.Mask)),
-		}
-		copy(c.Details.Subnets[i].IP, p.IP)
-		copy(c.Details.Subnets[i].Mask, p.Mask)
-	}
-
-	for g := range nc.Details.InvertedGroups {
-		c.Details.InvertedGroups[g] = struct{}{}
-	}
-
-	return c
-}
-
-func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
-	for _, net := range rootIps {
-		if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
-			return true
-		}
-	}
-
-	return false
-}
 
 
-func maskContains(caMask, certMask net.IPMask) bool {
-	caM := maskTo4(caMask)
-	cM := maskTo4(certMask)
-	// Make sure forcing to ipv4 didn't nuke us
-	if caM == nil || cM == nil {
-		return false
+	if c.Curve() != curve {
+		return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
 	}
 	}
 
 
-	// Make sure the cert mask is not greater than the ca mask
-	for i := 0; i < len(caMask); i++ {
-		if caM[i] > cM[i] {
-			return false
-		}
-	}
-
-	return true
-}
-
-func maskTo4(ip net.IPMask) net.IPMask {
-	if len(ip) == net.IPv4len {
-		return ip
-	}
-
-	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
-		return ip[12:16]
-	}
-
-	return nil
-}
-
-func isZeros(b []byte) bool {
-	for i := 0; i < len(b); i++ {
-		if b[i] != 0 {
-			return false
-		}
-	}
-	return true
-}
-
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, net.IPv4len)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
+	return c, nil
 }
 }

+ 0 - 1230
cert/cert_test.go

@@ -1,1230 +0,0 @@
-package cert
-
-import (
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/elliptic"
-	"crypto/rand"
-	"fmt"
-	"io"
-	"net"
-	"testing"
-	"time"
-
-	"github.com/slackhq/nebula/test"
-	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-	"google.golang.org/protobuf/proto"
-)
-
-func TestMarshalingNebulaCertificate(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-
-	nc2, err := UnmarshalNebulaCertificate(b)
-	assert.Nil(t, err)
-
-	assert.Equal(t, nc.Signature, nc2.Signature)
-	assert.Equal(t, nc.Details.Name, nc2.Details.Name)
-	assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore)
-	assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter)
-	assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey)
-	assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA)
-
-	// IP byte arrays can be 4 or 16 in length so we have to go this route
-	assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips))
-	for i, wIp := range nc.Details.Ips {
-		assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String())
-	}
-
-	assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets))
-	for i, wIp := range nc.Details.Subnets {
-		assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String())
-	}
-
-	assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups)
-}
-
-func TestNebulaCertificate_Sign(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-	}
-
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	assert.Nil(t, err)
-	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(Curve_CURVE25519, priv))
-	assert.True(t, nc.CheckSignature(pub))
-
-	_, err = nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-}
-
-func TestNebulaCertificate_SignP256(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Curve:     Curve_P256,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-	}
-
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	assert.Nil(t, err)
-	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(Curve_P256, rawPriv))
-	assert.True(t, nc.CheckSignature(pub))
-
-	_, err = nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-}
-
-func TestNebulaCertificate_Expired(t *testing.T) {
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
-			NotAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
-		},
-	}
-
-	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
-	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
-	assert.False(t, nc.Expired(time.Now()))
-}
-
-func TestNebulaCertificate_MarshalJSON(t *testing.T) {
-	time.Local = time.UTC
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
-			NotAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.MarshalJSON()
-	assert.Nil(t, err)
-	assert.Equal(
-		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
-		string(b),
-	)
-}
-
-func TestNebulaCertificate_Verify(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	h, err := ca.Sha256Sum()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.CAs[h] = ca
-
-	f, err := c.Sha256Sum()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is expired")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	h, err := ca.Sha256Sum()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.CAs[h] = ca
-
-	f, err := c.Sha256Sum()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is expired")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_IPs(t *testing.T) {
-	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
-	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	// ip is outside the network
-	cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}}
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_Subnets(t *testing.T) {
-	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
-	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	// ip is outside the network
-	cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}}
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := x25519Keypair()
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(Curve_P256, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := p256Keypair()
-	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNewCAPoolFromBytes(t *testing.T) {
-	noNewLines := `
-# Current provisional, Remove once everything moves over to the real root.
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-# root-ca01
------BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
------END NEBULA CERTIFICATE-----
-`
-
-	withNewLines := `
-# Current provisional, Remove once everything moves over to the real root.
-
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-
-# root-ca01
-
-
------BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
------END NEBULA CERTIFICATE-----
-
-`
-
-	expired := `
-# expired certificate
------BEGIN NEBULA CERTIFICATE-----
-CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4
-vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie
-WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
------END NEBULA CERTIFICATE-----
-`
-
-	p256 := `
-# p256 certificate
------BEGIN NEBULA CERTIFICATE-----
-CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
-6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H
-76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
-IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
------END NEBULA CERTIFICATE-----
-`
-
-	rootCA := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula root ca",
-		},
-	}
-
-	rootCA01 := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula root ca 01",
-		},
-	}
-
-	rootCAP256 := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula P256 test",
-		},
-	}
-
-	p, err := NewCAPoolFromBytes([]byte(noNewLines))
-	assert.Nil(t, err)
-	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-
-	pp, err := NewCAPoolFromBytes([]byte(withNewLines))
-	assert.Nil(t, err)
-	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-
-	// expired cert, no valid certs
-	ppp, err := NewCAPoolFromBytes([]byte(expired))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
-
-	// expired cert, with valid certs
-	pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
-	assert.Equal(t, len(pppp.CAs), 3)
-
-	ppppp, err := NewCAPoolFromBytes([]byte(p256))
-	assert.Nil(t, err)
-	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
-	assert.Equal(t, len(ppppp.CAs), 1)
-}
-
-func appendByteSlices(b ...[]byte) []byte {
-	retSlice := []byte{}
-	for _, v := range b {
-		retSlice = append(retSlice, v...)
-	}
-	return retSlice
-}
-
-func TestUnmrshalCertPEM(t *testing.T) {
-	goodCert := []byte(`
-# A good cert
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-`)
-	badBanner := []byte(`# A bad banner
------BEGIN NOT A NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NOT A NEBULA CERTIFICATE-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
--END NEBULA CERTIFICATE----`)
-
-	certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
-
-	// Success test case
-	cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle)
-	assert.NotNil(t, cert)
-	assert.Equal(t, rest, append(badBanner, invalidPem...))
-	assert.Nil(t, err)
-
-	// Fail due to invalid banner.
-	cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
-	assert.Nil(t, cert)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
-	assert.Nil(t, cert)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalSigningPrivateKey(t *testing.T) {
-	privKey := []byte(`# A good key
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	privP256Key := []byte(`# A good key
------BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ECDSA P256 PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NOT A NEBULA PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
--END NEBULA ED25519 PRIVATE KEY-----`)
-
-	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle)
-	assert.Len(t, k, 64)
-	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-	assert.Nil(t, err)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
-	passphrase := []byte("DO NOT USE THIS KEY")
-	privKey := []byte(`# A good key
------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
------END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A key which, once decrypted, is too short
------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
-k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
-GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
-rQr3bdH3Oy/WiYU=
------END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner (not encrypted)
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
-XgLvodMXZJuaFPssp+WwtA==
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
--END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-
-	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
-	assert.Nil(t, err)
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Len(t, k, 64)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-
-	// Fail due to short key
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-
-	// Fail due to invalid banner
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to invalid passphrase
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
-	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, []byte{})
-}
-
-func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
-	// Having proved that decryption works correctly above, we can test the
-	// encryption function produces a value which can be decrypted
-	passphrase := []byte("passphrase")
-	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
-	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
-	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
-	assert.Nil(t, err)
-
-	// Verify the "key" can be decrypted successfully
-	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
-	assert.Len(t, k, 64)
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Equal(t, rest, []byte{})
-	assert.Nil(t, err)
-
-	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
-}
-
-func TestUnmarshalPrivateKey(t *testing.T) {
-	privKey := []byte(`# A good key
------BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA X25519 PRIVATE KEY-----
-`)
-	privP256Key := []byte(`# A good key
------BEGIN NEBULA P256 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA X25519 PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA X25519 PRIVATE KEY-----`)
-
-	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalPrivateKey(keyBundle)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-	assert.Nil(t, err)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalEd25519PublicKey(t *testing.T) {
-	pubKey := []byte(`# A good key
------BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ED25519 PUBLIC KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA ED25519 PUBLIC KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PUBLIC KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA ED25519 PUBLIC KEY-----`)
-
-	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, err := UnmarshalEd25519PublicKey(keyBundle)
-	assert.Equal(t, len(k), 32)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-
-	// Fail due to short key
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key")
-
-	// Fail due to invalid banner
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner")
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalX25519PublicKey(t *testing.T) {
-	pubKey := []byte(`# A good key
------BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA X25519 PUBLIC KEY-----
-`)
-	pubP256Key := []byte(`# A good key
------BEGIN NEBULA P256 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-AAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PUBLIC KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA X25519 PUBLIC KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PUBLIC KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA X25519 PUBLIC KEY-----`)
-
-	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalPublicKey(keyBundle)
-	assert.Equal(t, len(k), 32)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Equal(t, len(k), 65)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner")
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-// Ensure that upgrading the protobuf library does not change how certificates
-// are marshalled, since this would break signature verification
-func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
-	before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
-	after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-	assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
-
-	b, err = proto.Marshal(nc.getRawDetails())
-	assert.Nil(t, err)
-	//t.Log("Raw cert size:", len(b))
-	assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
-}
-
-func TestNebulaCertificate_Copy(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	cc := c.Copy()
-
-	test.AssertDeepCopyEqual(t, c, cc)
-}
-
-func TestUnmarshalNebulaCertificate(t *testing.T) {
-	// Test that we don't panic with an invalid certificate (#332)
-	data := []byte("\x98\x00\x00")
-	_, err := UnmarshalNebulaCertificate(data)
-	assert.EqualError(t, err, "encoded Details was nil")
-}
-
-func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(Curve_CURVE25519, priv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, priv, nil
-}
-
-func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			Curve:          Curve_P256,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(Curve_P256, rawPriv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, rawPriv, nil
-}
-
-func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	if len(groups) == 0 {
-		groups = []string{"test-group1", "test-group2", "test-group3"}
-	}
-
-	if len(ips) == 0 {
-		ips = []*net.IPNet{
-			{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
-			{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
-			{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
-		}
-	}
-
-	if len(subnets) == 0 {
-		subnets = []*net.IPNet{
-			{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
-			{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
-			{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
-		}
-	}
-
-	var pub, rawPriv []byte
-
-	switch ca.Details.Curve {
-	case Curve_CURVE25519:
-		pub, rawPriv = x25519Keypair()
-	case Curve_P256:
-		pub, rawPriv = p256Keypair()
-	default:
-		return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "testing",
-			Ips:            ips,
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Curve:          ca.Details.Curve,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	return nc, pub, rawPriv, nil
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
-func p256Keypair() ([]byte, []byte) {
-	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
-	if err != nil {
-		panic(err)
-	}
-	pubkey := privkey.PublicKey()
-	return pubkey.Bytes(), privkey.Bytes()
-}

+ 489 - 0
cert/cert_v1.go

@@ -0,0 +1,489 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"google.golang.org/protobuf/proto"
+)
+
+const publicKeyLen = 32
+
+type certificateV1 struct {
+	details   detailsV1
+	signature []byte
+}
+
+type detailsV1 struct {
+	name           string
+	networks       []netip.Prefix
+	unsafeNetworks []netip.Prefix
+	groups         []string
+	notBefore      time.Time
+	notAfter       time.Time
+	publicKey      []byte
+	isCA           bool
+	issuer         string
+
+	curve Curve
+}
+
+type m map[string]interface{}
+
+func (c *certificateV1) Version() Version {
+	return Version1
+}
+
+func (c *certificateV1) Curve() Curve {
+	return c.details.curve
+}
+
+func (c *certificateV1) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV1) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV1) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV1) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV1) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV1) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV1) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV1) PublicKey() []byte {
+	return c.details.publicKey
+}
+
+func (c *certificateV1) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV1) Fingerprint() (string, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return "", err
+	}
+
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV1) CheckSignature(key []byte) bool {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return false
+	}
+	switch c.details.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV1) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.details.curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+			}
+
+			if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return fmt.Errorf("cannot parse private key as P256: %w", err)
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.details.publicKey) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return err
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return err
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.details.publicKey) {
+		return fmt.Errorf("public key in cert and private key supplied don't match")
+	}
+
+	return nil
+}
+
+// getRawDetails marshals the raw details into protobuf ready struct
+func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
+	rd := &RawNebulaCertificateDetails{
+		Name:      c.details.name,
+		Groups:    c.details.groups,
+		NotBefore: c.details.notBefore.Unix(),
+		NotAfter:  c.details.notAfter.Unix(),
+		PublicKey: make([]byte, len(c.details.publicKey)),
+		IsCA:      c.details.isCA,
+		Curve:     c.details.curve,
+	}
+
+	for _, ipNet := range c.details.networks {
+		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
+		rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
+	}
+
+	for _, ipNet := range c.details.unsafeNetworks {
+		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
+		rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
+	}
+
+	copy(rd.PublicKey, c.details.publicKey[:])
+
+	// I know, this is terrible
+	rd.Issuer, _ = hex.DecodeString(c.details.issuer)
+
+	return rd
+}
+
+func (c *certificateV1) String() string {
+	b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
+	pubKey := c.details.publicKey
+	c.details.publicKey = nil
+	rawCertNoKey, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	c.details.publicKey = pubKey
+	return rawCertNoKey, nil
+}
+
+func (c *certificateV1) Marshal() ([]byte, error) {
+	rc := RawNebulaCertificate{
+		Details:   c.getRawDetails(),
+		Signature: c.signature,
+	}
+
+	return proto.Marshal(&rc)
+}
+
+func (c *certificateV1) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
+}
+
+func (c *certificateV1) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.marshalJSON())
+}
+
+func (c *certificateV1) marshalJSON() m {
+	fp, _ := c.Fingerprint()
+	return m{
+		"version": Version1,
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"publicKey":      fmt.Sprintf("%x", c.details.publicKey),
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+			"curve":          c.details.curve.String(),
+		},
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}
+}
+
+func (c *certificateV1) Copy() Certificate {
+	nc := &certificateV1{
+		details: detailsV1{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			publicKey: make([]byte, len(c.details.publicKey)),
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+			curve:     c.details.curve,
+		},
+		signature: make([]byte, len(c.signature)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.signature, c.signature)
+	copy(nc.details.publicKey, c.details.publicKey)
+
+	return nc
+}
+
+func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV1{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		publicKey:      t.PublicKey,
+		isCA:           t.IsCA,
+		curve:          t.Curve,
+		issuer:         t.issuer,
+	}
+
+	return c.validate()
+}
+
+func (c *certificateV1) validate() error {
+	// Empty names are allowed
+
+	if len(c.details.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
+	// Continue to allow this behavior
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
+	}
+
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+	}
+
+	// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
+	// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
+	// unsafe networks would result in a different signature.
+
+	return nil
+}
+
+func (c *certificateV1) marshalForSigning() ([]byte, error) {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return nil, err
+	}
+	return b, nil
+}
+
+func (c *certificateV1) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
+// if the publicKey is provided here then it is not required to be present in `b`
+func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rc RawNebulaCertificate
+	err := proto.Unmarshal(b, &rc)
+	if err != nil {
+		return nil, err
+	}
+
+	if rc.Details == nil {
+		return nil, fmt.Errorf("encoded Details was nil")
+	}
+
+	if len(rc.Details.Ips)%2 != 0 {
+		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
+	}
+
+	if len(rc.Details.Subnets)%2 != 0 {
+		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
+	}
+
+	nc := certificateV1{
+		details: detailsV1{
+			name:           rc.Details.Name,
+			groups:         make([]string, len(rc.Details.Groups)),
+			networks:       make([]netip.Prefix, len(rc.Details.Ips)/2),
+			unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
+			notBefore:      time.Unix(rc.Details.NotBefore, 0),
+			notAfter:       time.Unix(rc.Details.NotAfter, 0),
+			publicKey:      make([]byte, len(rc.Details.PublicKey)),
+			isCA:           rc.Details.IsCA,
+			curve:          rc.Details.Curve,
+		},
+		signature: make([]byte, len(rc.Signature)),
+	}
+
+	copy(nc.signature, rc.Signature)
+	copy(nc.details.groups, rc.Details.Groups)
+	nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
+
+	if len(publicKey) > 0 {
+		nc.details.publicKey = publicKey
+	}
+
+	copy(nc.details.publicKey, rc.Details.PublicKey)
+
+	var ip netip.Addr
+	for i, rawIp := range rc.Details.Ips {
+		if i%2 == 0 {
+			ip = int2addr(rawIp)
+		} else {
+			ones, _ := net.IPMask(int2ip(rawIp)).Size()
+			nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
+		}
+	}
+
+	for i, rawIp := range rc.Details.Subnets {
+		if i%2 == 0 {
+			ip = int2addr(rawIp)
+		} else {
+			ones, _ := net.IPMask(int2ip(rawIp)).Size()
+			nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
+		}
+	}
+
+	err = nc.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return &nc, nil
+}
+
+func ip2int(ip []byte) uint32 {
+	if len(ip) == 16 {
+		return binary.BigEndian.Uint32(ip[12:16])
+	}
+	return binary.BigEndian.Uint32(ip)
+}
+
+func int2ip(nn uint32) net.IP {
+	ip := make(net.IP, net.IPv4len)
+	binary.BigEndian.PutUint32(ip, nn)
+	return ip
+}
+
+func addr2int(addr netip.Addr) uint32 {
+	b := addr.Unmap().As4()
+	return binary.BigEndian.Uint32(b[:])
+}
+
+func int2addr(nn uint32) netip.Addr {
+	ip := [4]byte{}
+	binary.BigEndian.PutUint32(ip[:], nn)
+	return netip.AddrFrom4(ip).Unmap()
+}

+ 111 - 111
cert/cert.pb.go → cert/cert_v1.pb.go

@@ -1,8 +1,8 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
 // versions:
-// 	protoc-gen-go v1.30.0
+// 	protoc-gen-go v1.34.2
 // 	protoc        v3.21.5
 // 	protoc        v3.21.5
-// source: cert.proto
+// source: cert_v1.proto
 
 
 package cert
 package cert
 
 
@@ -50,11 +50,11 @@ func (x Curve) String() string {
 }
 }
 
 
 func (Curve) Descriptor() protoreflect.EnumDescriptor {
 func (Curve) Descriptor() protoreflect.EnumDescriptor {
-	return file_cert_proto_enumTypes[0].Descriptor()
+	return file_cert_v1_proto_enumTypes[0].Descriptor()
 }
 }
 
 
 func (Curve) Type() protoreflect.EnumType {
 func (Curve) Type() protoreflect.EnumType {
-	return &file_cert_proto_enumTypes[0]
+	return &file_cert_v1_proto_enumTypes[0]
 }
 }
 
 
 func (x Curve) Number() protoreflect.EnumNumber {
 func (x Curve) Number() protoreflect.EnumNumber {
@@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber {
 
 
 // Deprecated: Use Curve.Descriptor instead.
 // Deprecated: Use Curve.Descriptor instead.
 func (Curve) EnumDescriptor() ([]byte, []int) {
 func (Curve) EnumDescriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{0}
+	return file_cert_v1_proto_rawDescGZIP(), []int{0}
 }
 }
 
 
 type RawNebulaCertificate struct {
 type RawNebulaCertificate struct {
@@ -78,7 +78,7 @@ type RawNebulaCertificate struct {
 func (x *RawNebulaCertificate) Reset() {
 func (x *RawNebulaCertificate) Reset() {
 	*x = RawNebulaCertificate{}
 	*x = RawNebulaCertificate{}
 	if protoimpl.UnsafeEnabled {
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[0]
+		mi := &file_cert_v1_proto_msgTypes[0]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 		ms.StoreMessageInfo(mi)
 	}
 	}
@@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string {
 func (*RawNebulaCertificate) ProtoMessage() {}
 func (*RawNebulaCertificate) ProtoMessage() {}
 
 
 func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
 func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[0]
+	mi := &file_cert_v1_proto_msgTypes[0]
 	if protoimpl.UnsafeEnabled && x != nil {
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 		if ms.LoadMessageInfo() == nil {
@@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
 
 
 // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
 // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
 func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
 func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{0}
+	return file_cert_v1_proto_rawDescGZIP(), []int{0}
 }
 }
 
 
 func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
 func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
@@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct {
 func (x *RawNebulaCertificateDetails) Reset() {
 func (x *RawNebulaCertificateDetails) Reset() {
 	*x = RawNebulaCertificateDetails{}
 	*x = RawNebulaCertificateDetails{}
 	if protoimpl.UnsafeEnabled {
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[1]
+		mi := &file_cert_v1_proto_msgTypes[1]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 		ms.StoreMessageInfo(mi)
 	}
 	}
@@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string {
 func (*RawNebulaCertificateDetails) ProtoMessage() {}
 func (*RawNebulaCertificateDetails) ProtoMessage() {}
 
 
 func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
 func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[1]
+	mi := &file_cert_v1_proto_msgTypes[1]
 	if protoimpl.UnsafeEnabled && x != nil {
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 		if ms.LoadMessageInfo() == nil {
@@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
 
 
 // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
 // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
 func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
 func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{1}
+	return file_cert_v1_proto_rawDescGZIP(), []int{1}
 }
 }
 
 
 func (x *RawNebulaCertificateDetails) GetName() string {
 func (x *RawNebulaCertificateDetails) GetName() string {
@@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct {
 func (x *RawNebulaEncryptedData) Reset() {
 func (x *RawNebulaEncryptedData) Reset() {
 	*x = RawNebulaEncryptedData{}
 	*x = RawNebulaEncryptedData{}
 	if protoimpl.UnsafeEnabled {
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[2]
+		mi := &file_cert_v1_proto_msgTypes[2]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 		ms.StoreMessageInfo(mi)
 	}
 	}
@@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string {
 func (*RawNebulaEncryptedData) ProtoMessage() {}
 func (*RawNebulaEncryptedData) ProtoMessage() {}
 
 
 func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
 func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[2]
+	mi := &file_cert_v1_proto_msgTypes[2]
 	if protoimpl.UnsafeEnabled && x != nil {
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 		if ms.LoadMessageInfo() == nil {
@@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
 
 
 // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
 // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
 func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
 func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{2}
+	return file_cert_v1_proto_rawDescGZIP(), []int{2}
 }
 }
 
 
 func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
 func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
@@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct {
 func (x *RawNebulaEncryptionMetadata) Reset() {
 func (x *RawNebulaEncryptionMetadata) Reset() {
 	*x = RawNebulaEncryptionMetadata{}
 	*x = RawNebulaEncryptionMetadata{}
 	if protoimpl.UnsafeEnabled {
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[3]
+		mi := &file_cert_v1_proto_msgTypes[3]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 		ms.StoreMessageInfo(mi)
 	}
 	}
@@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string {
 func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
 func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
 
 
 func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
 func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[3]
+	mi := &file_cert_v1_proto_msgTypes[3]
 	if protoimpl.UnsafeEnabled && x != nil {
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 		if ms.LoadMessageInfo() == nil {
@@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
 
 
 // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
 // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
 func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
 func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{3}
+	return file_cert_v1_proto_rawDescGZIP(), []int{3}
 }
 }
 
 
 func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
 func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
@@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct {
 func (x *RawNebulaArgon2Parameters) Reset() {
 func (x *RawNebulaArgon2Parameters) Reset() {
 	*x = RawNebulaArgon2Parameters{}
 	*x = RawNebulaArgon2Parameters{}
 	if protoimpl.UnsafeEnabled {
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[4]
+		mi := &file_cert_v1_proto_msgTypes[4]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 		ms.StoreMessageInfo(mi)
 	}
 	}
@@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string {
 func (*RawNebulaArgon2Parameters) ProtoMessage() {}
 func (*RawNebulaArgon2Parameters) ProtoMessage() {}
 
 
 func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
 func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[4]
+	mi := &file_cert_v1_proto_msgTypes[4]
 	if protoimpl.UnsafeEnabled && x != nil {
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 		if ms.LoadMessageInfo() == nil {
@@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
 
 
 // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
 // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
 func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
 func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{4}
+	return file_cert_v1_proto_rawDescGZIP(), []int{4}
 }
 }
 
 
 func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
 func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
@@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
 	return nil
 	return nil
 }
 }
 
 
-var File_cert_proto protoreflect.FileDescriptor
-
-var file_cert_proto_rawDesc = []byte{
-	0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65,
-	0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
-	0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65,
-	0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74,
-	0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
-	0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
-	0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
-	0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
-	0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
-	0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
-	0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
-	0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53,
-	0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75,
-	0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
-	0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a,
-	0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03,
-	0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e,
-	0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e,
-	0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69,
-	0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c,
-	0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
-	0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
-	0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
-	0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e,
-	0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63,
-	0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
-	0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12,
-	0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
-	0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72,
-	0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12,
-	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74,
-	0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
-	0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61,
-	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
-	0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
-	0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72,
-	0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61,
-	0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f,
-	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
-	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52,
-	0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72,
-	0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
-	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12,
-	0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05,
-	0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d,
-	0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72,
-	0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
-	0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
-	0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
-	0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69,
-	0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
-	0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65,
-	0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00,
-	0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69,
-	0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71,
-	0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72,
-	0x6f, 0x74, 0x6f, 0x33,
+var File_cert_v1_proto protoreflect.FileDescriptor
+
+var file_cert_v1_proto_rawDesc = []byte{
+	0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
+	0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a,
+	0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
+	0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c,
+	0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69,
+	0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53,
+	0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77,
+	0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74,
+	0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
+	0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03,
+	0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18,
+	0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52,
+	0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75,
+	0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73,
+	0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20,
+	0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a,
+	0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03,
+	0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75,
+	0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50,
+	0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41,
+	0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06,
+	0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73,
+	0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20,
+	0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65,
+	0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e,
+	0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61,
+	0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
+	0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
+	0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
+	0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74,
+	0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65,
+	0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
+	0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
+	0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01,
+	0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c,
+	0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e,
+	0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28,
+	0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
+	0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65,
+	0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
+	0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
+	0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06,
+	0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65,
+	0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
+	0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c,
+	0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74,
+	0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72,
+	0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05,
+	0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75,
+	0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31,
+	0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a,
+	0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63,
+	0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62,
+	0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 }
 
 
 var (
 var (
-	file_cert_proto_rawDescOnce sync.Once
-	file_cert_proto_rawDescData = file_cert_proto_rawDesc
+	file_cert_v1_proto_rawDescOnce sync.Once
+	file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc
 )
 )
 
 
-func file_cert_proto_rawDescGZIP() []byte {
-	file_cert_proto_rawDescOnce.Do(func() {
-		file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData)
+func file_cert_v1_proto_rawDescGZIP() []byte {
+	file_cert_v1_proto_rawDescOnce.Do(func() {
+		file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData)
 	})
 	})
-	return file_cert_proto_rawDescData
+	return file_cert_v1_proto_rawDescData
 }
 }
 
 
-var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
-var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
-var file_cert_proto_goTypes = []interface{}{
+var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
+var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
+var file_cert_v1_proto_goTypes = []any{
 	(Curve)(0),                          // 0: cert.Curve
 	(Curve)(0),                          // 0: cert.Curve
 	(*RawNebulaCertificate)(nil),        // 1: cert.RawNebulaCertificate
 	(*RawNebulaCertificate)(nil),        // 1: cert.RawNebulaCertificate
 	(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
 	(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
@@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{
 	(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
 	(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
 	(*RawNebulaArgon2Parameters)(nil),   // 5: cert.RawNebulaArgon2Parameters
 	(*RawNebulaArgon2Parameters)(nil),   // 5: cert.RawNebulaArgon2Parameters
 }
 }
-var file_cert_proto_depIdxs = []int32{
+var file_cert_v1_proto_depIdxs = []int32{
 	2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
 	2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
 	0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
 	0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
 	4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
 	4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
@@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{
 	0, // [0:4] is the sub-list for field type_name
 	0, // [0:4] is the sub-list for field type_name
 }
 }
 
 
-func init() { file_cert_proto_init() }
-func file_cert_proto_init() {
-	if File_cert_proto != nil {
+func init() { file_cert_v1_proto_init() }
+func file_cert_v1_proto_init() {
+	if File_cert_v1_proto != nil {
 		return
 		return
 	}
 	}
 	if !protoimpl.UnsafeEnabled {
 	if !protoimpl.UnsafeEnabled {
-		file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaCertificate); i {
 			switch v := v.(*RawNebulaCertificate); i {
 			case 0:
 			case 0:
 				return &v.state
 				return &v.state
@@ -549,7 +549,7 @@ func file_cert_proto_init() {
 				return nil
 				return nil
 			}
 			}
 		}
 		}
-		file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaCertificateDetails); i {
 			switch v := v.(*RawNebulaCertificateDetails); i {
 			case 0:
 			case 0:
 				return &v.state
 				return &v.state
@@ -561,7 +561,7 @@ func file_cert_proto_init() {
 				return nil
 				return nil
 			}
 			}
 		}
 		}
-		file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaEncryptedData); i {
 			switch v := v.(*RawNebulaEncryptedData); i {
 			case 0:
 			case 0:
 				return &v.state
 				return &v.state
@@ -573,7 +573,7 @@ func file_cert_proto_init() {
 				return nil
 				return nil
 			}
 			}
 		}
 		}
-		file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaEncryptionMetadata); i {
 			switch v := v.(*RawNebulaEncryptionMetadata); i {
 			case 0:
 			case 0:
 				return &v.state
 				return &v.state
@@ -585,7 +585,7 @@ func file_cert_proto_init() {
 				return nil
 				return nil
 			}
 			}
 		}
 		}
-		file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaArgon2Parameters); i {
 			switch v := v.(*RawNebulaArgon2Parameters); i {
 			case 0:
 			case 0:
 				return &v.state
 				return &v.state
@@ -602,19 +602,19 @@ func file_cert_proto_init() {
 	out := protoimpl.TypeBuilder{
 	out := protoimpl.TypeBuilder{
 		File: protoimpl.DescBuilder{
 		File: protoimpl.DescBuilder{
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
-			RawDescriptor: file_cert_proto_rawDesc,
+			RawDescriptor: file_cert_v1_proto_rawDesc,
 			NumEnums:      1,
 			NumEnums:      1,
 			NumMessages:   5,
 			NumMessages:   5,
 			NumExtensions: 0,
 			NumExtensions: 0,
 			NumServices:   0,
 			NumServices:   0,
 		},
 		},
-		GoTypes:           file_cert_proto_goTypes,
-		DependencyIndexes: file_cert_proto_depIdxs,
-		EnumInfos:         file_cert_proto_enumTypes,
-		MessageInfos:      file_cert_proto_msgTypes,
+		GoTypes:           file_cert_v1_proto_goTypes,
+		DependencyIndexes: file_cert_v1_proto_depIdxs,
+		EnumInfos:         file_cert_v1_proto_enumTypes,
+		MessageInfos:      file_cert_v1_proto_msgTypes,
 	}.Build()
 	}.Build()
-	File_cert_proto = out.File
-	file_cert_proto_rawDesc = nil
-	file_cert_proto_goTypes = nil
-	file_cert_proto_depIdxs = nil
+	File_cert_v1_proto = out.File
+	file_cert_v1_proto_rawDesc = nil
+	file_cert_v1_proto_goTypes = nil
+	file_cert_v1_proto_depIdxs = nil
 }
 }

+ 0 - 0
cert/cert.proto → cert/cert_v1.proto


+ 218 - 0
cert/cert_v1_test.go

@@ -0,0 +1,218 @@
+package cert
+
+import (
+	"fmt"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/protobuf/proto"
+)
+
+func TestCertificateV1_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	assert.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+
+	assert.Equal(t, nc.Version(), Version1)
+	assert.Equal(t, nc.Curve(), Curve_CURVE25519)
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV1_Expired(t *testing.T) {
+	nc := certificateV1{
+		details: detailsV1{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV1_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
+		string(b),
+	)
+}
+
+func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.NotNil(t, err)
+}
+
+// Ensure that upgrading the protobuf library does not change how certificates
+// are marshalled, since this would break signature verification
+func TestMarshalingCertificateV1Consistency(t *testing.T) {
+	before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
+	after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	require.Nil(t, err)
+	assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
+
+	b, err = proto.Marshal(nc.getRawDetails())
+	assert.Nil(t, err)
+	assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
+}
+
+func TestCertificateV1_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV1(t *testing.T) {
+	// Test that we don't panic with an invalid certificate (#332)
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV1(data, nil)
+	assert.EqualError(t, err, "encoded Details was nil")
+}
+
+func appendByteSlices(b ...[]byte) []byte {
+	retSlice := []byte{}
+	for _, v := range b {
+		retSlice = append(retSlice, v...)
+	}
+	return retSlice
+}
+
+func mustParsePrefixUnmapped(s string) netip.Prefix {
+	prefix := netip.MustParsePrefix(s)
+	return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
+}

+ 37 - 0
cert/cert_v2.asn1

@@ -0,0 +1,37 @@
+Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
+
+Name ::= UTF8String (SIZE (1..253))
+Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
+Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
+Curve ::= ENUMERATED {
+    curve25519 (0),
+    p256 (1)
+}
+
+-- The maximum size of a certificate must not exceed 65536 bytes
+Certificate ::= SEQUENCE {
+    details OCTET STRING,
+    curve Curve DEFAULT curve25519,
+    publicKey OCTET STRING,
+    -- signature(details + curve + publicKey) using the appropriate method for curve
+    signature OCTET STRING
+}
+
+Details ::= SEQUENCE {
+    name Name,
+
+    -- At least 1 ipv4 or ipv6 address must be present if isCA is false
+    networks SEQUENCE OF Network OPTIONAL,
+    unsafeNetworks SEQUENCE OF Network OPTIONAL,
+    groups SEQUENCE OF Name OPTIONAL,
+    isCA BOOLEAN DEFAULT false,
+    notBefore Time,
+    notAfter Time,
+
+    -- issuer is only required if isCA is false, if isCA is true then it must not be present
+    issuer OCTET STRING OPTIONAL,
+    ...
+    -- New fields can be added below here
+}
+
+END

+ 730 - 0
cert/cert_v2.go

@@ -0,0 +1,730 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net/netip"
+	"slices"
+	"time"
+
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+	"golang.org/x/crypto/curve25519"
+)
+
+const (
+	classConstructed     = 0x20
+	classContextSpecific = 0x80
+
+	TagCertDetails   = 0 | classConstructed | classContextSpecific
+	TagCertCurve     = 1 | classContextSpecific
+	TagCertPublicKey = 2 | classContextSpecific
+	TagCertSignature = 3 | classContextSpecific
+
+	TagDetailsName           = 0 | classContextSpecific
+	TagDetailsNetworks       = 1 | classConstructed | classContextSpecific
+	TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
+	TagDetailsGroups         = 3 | classConstructed | classContextSpecific
+	TagDetailsIsCA           = 4 | classContextSpecific
+	TagDetailsNotBefore      = 5 | classContextSpecific
+	TagDetailsNotAfter       = 6 | classContextSpecific
+	TagDetailsIssuer         = 7 | classContextSpecific
+)
+
+const (
+	// MaxCertificateSize is the maximum length a valid certificate can be
+	MaxCertificateSize = 65536
+
+	// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
+	MaxNameLength = 253
+
+	// MaxNetworkLength is the maximum length a network value can be.
+	// 16 bytes for an ipv6 address + 1 byte for the prefix length
+	MaxNetworkLength = 17
+)
+
+type certificateV2 struct {
+	details detailsV2
+
+	// RawDetails contains the entire asn.1 DER encoded Details struct
+	// This is to benefit forwards compatibility in signature checking.
+	// signature(RawDetails + Curve + PublicKey) == Signature
+	rawDetails []byte
+	curve      Curve
+	publicKey  []byte
+	signature  []byte
+}
+
+type detailsV2 struct {
+	name           string
+	networks       []netip.Prefix // MUST BE SORTED
+	unsafeNetworks []netip.Prefix // MUST BE SORTED
+	groups         []string
+	isCA           bool
+	notBefore      time.Time
+	notAfter       time.Time
+	issuer         string
+}
+
+func (c *certificateV2) Version() Version {
+	return Version2
+}
+
+func (c *certificateV2) Curve() Curve {
+	return c.curve
+}
+
+func (c *certificateV2) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV2) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV2) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV2) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV2) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV2) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV2) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV2) PublicKey() []byte {
+	return c.publicKey
+}
+
+func (c *certificateV2) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV2) Fingerprint() (string, error) {
+	if len(c.rawDetails) == 0 {
+		return "", ErrMissingDetails
+	}
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV2) CheckSignature(key []byte) bool {
+	if len(c.rawDetails) == 0 {
+		return false
+	}
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+
+	switch c.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV2) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.curve {
+		return ErrPublicPrivateCurveMismatch
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return ErrInvalidPrivateKey
+			}
+
+			if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return ErrInvalidPrivateKey
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.publicKey) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.publicKey) {
+		return ErrPublicPrivateKeyMismatch
+	}
+
+	return nil
+}
+
+func (c *certificateV2) String() string {
+	mb, err := c.marshalJSON()
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+
+	b, err := json.MarshalIndent(mb, "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Skipping the curve and public key since those come across in a different part of the handshake
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) Marshal() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Add the curve only if its not the default value
+		if c.curve != Curve_CURVE25519 {
+			b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
+				b.AddBytes([]byte{byte(c.curve)})
+			})
+		}
+
+		// Add the public key if it is not empty
+		if c.publicKey != nil {
+			b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
+				b.AddBytes(c.publicKey)
+			})
+		}
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
+}
+
+func (c *certificateV2) MarshalJSON() ([]byte, error) {
+	b, err := c.marshalJSON()
+	if err != nil {
+		return nil, err
+	}
+	return json.Marshal(b)
+}
+
+func (c *certificateV2) marshalJSON() (m, error) {
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, err
+	}
+
+	return m{
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+		},
+		"version":     Version2,
+		"publicKey":   fmt.Sprintf("%x", c.publicKey),
+		"curve":       c.curve.String(),
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}, nil
+}
+
+func (c *certificateV2) Copy() Certificate {
+	nc := &certificateV2{
+		details: detailsV2{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+		},
+		curve:      c.curve,
+		publicKey:  make([]byte, len(c.publicKey)),
+		signature:  make([]byte, len(c.signature)),
+		rawDetails: make([]byte, len(c.rawDetails)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.rawDetails, c.rawDetails)
+	copy(nc.signature, c.signature)
+	copy(nc.publicKey, c.publicKey)
+
+	return nc
+}
+
+func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV2{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		isCA:           t.IsCA,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		issuer:         t.issuer,
+	}
+	c.curve = t.Curve
+	c.publicKey = t.PublicKey
+	return c.validate()
+}
+
+func (c *certificateV2) validate() error {
+	// Empty names are allowed
+
+	if len(c.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
+	}
+
+	hasV4Networks := false
+	hasV6Networks := false
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+
+		if network.Addr().Is4In6() {
+			return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
+		}
+
+		hasV4Networks = hasV4Networks || network.Addr().Is4()
+		hasV6Networks = hasV6Networks || network.Addr().Is6()
+	}
+
+	slices.SortFunc(c.details.networks, comparePrefix)
+	err := findDuplicatePrefix(c.details.networks)
+	if err != nil {
+		return err
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+
+		if !c.details.isCA {
+			if network.Addr().Is6() {
+				if !hasV6Networks {
+					return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
+				}
+			} else if network.Addr().Is4() {
+				if !hasV4Networks {
+					return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
+				}
+			}
+		}
+	}
+
+	slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
+	err = findDuplicatePrefix(c.details.unsafeNetworks)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *certificateV2) marshalForSigning() ([]byte, error) {
+	d, err := c.details.Marshal()
+	if err != nil {
+		return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
+	}
+	c.rawDetails = d
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	return b, nil
+}
+
+func (c *certificateV2) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+func (d *detailsV2) Marshal() ([]byte, error) {
+	var b cryptobyte.Builder
+	var err error
+
+	// Details are a structure
+	b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
+
+		// Add the name
+		b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
+			b.AddBytes([]byte(d.name))
+		})
+
+		// Add the networks if any exist
+		if len(d.networks) > 0 {
+			b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.networks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add the unsafe networks if any exist
+		if len(d.unsafeNetworks) > 0 {
+			b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.unsafeNetworks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add groups if any exist
+		if len(d.groups) > 0 {
+			b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
+				for _, group := range d.groups {
+					b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
+						b.AddBytes([]byte(group))
+					})
+				}
+			})
+		}
+
+		// Add IsCA only if true
+		if d.isCA {
+			b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
+				b.AddUint8(0xff)
+			})
+		}
+
+		// Add not before
+		b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
+
+		// Add not after
+		b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
+
+		// Add the issuer if present
+		if d.issuer != "" {
+			issuerBytes, innerErr := hex.DecodeString(d.issuer)
+			if innerErr != nil {
+				err = fmt.Errorf("failed to decode issuer: %w", innerErr)
+				return
+			}
+			b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
+				b.AddBytes(issuerBytes)
+			})
+		}
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return b.Bytes()
+}
+
+func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
+	l := len(b)
+	if l == 0 || l > MaxCertificateSize {
+		return nil, ErrBadFormat
+	}
+
+	input := cryptobyte.String(b)
+	// Open the envelope
+	if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the cert details, we need to preserve the tag and length
+	var rawDetails cryptobyte.String
+	if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	//Maybe grab the curve
+	var rawCurve byte
+	if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
+		return nil, ErrBadFormat
+	}
+	curve = Curve(rawCurve)
+
+	// Maybe grab the public key
+	var rawPublicKey cryptobyte.String
+	if len(publicKey) > 0 {
+		rawPublicKey = publicKey
+	} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
+		return nil, ErrBadFormat
+	}
+
+	if len(rawPublicKey) == 0 {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the signature
+	var rawSignature cryptobyte.String
+	if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Finally unmarshal the details
+	details, err := unmarshalDetails(rawDetails)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &certificateV2{
+		details:    details,
+		rawDetails: rawDetails,
+		curve:      curve,
+		publicKey:  rawPublicKey,
+		signature:  rawSignature,
+	}
+
+	err = c.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
+
+func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
+	// Open the envelope
+	if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the name
+	var name cryptobyte.String
+	if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the network addresses
+	var subString cryptobyte.String
+	var found bool
+
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var networks []netip.Prefix
+	var val cryptobyte.String
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			networks = append(networks, n)
+		}
+	}
+
+	// Read out any unsafe networks
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var unsafeNetworks []netip.Prefix
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			unsafeNetworks = append(unsafeNetworks, n)
+		}
+	}
+
+	// Read out any groups
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var groups []string
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
+				return detailsV2{}, ErrBadFormat
+			}
+			groups = append(groups, string(val))
+		}
+	}
+
+	// Read out IsCA
+	var isCa bool
+	if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read not before and not after
+	var notBefore int64
+	if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var notAfter int64
+	if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read issuer
+	var issuer cryptobyte.String
+	if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	return detailsV2{
+		name:           string(name),
+		networks:       networks,
+		unsafeNetworks: unsafeNetworks,
+		groups:         groups,
+		isCA:           isCa,
+		notBefore:      time.Unix(notBefore, 0),
+		notAfter:       time.Unix(notAfter, 0),
+		issuer:         hex.EncodeToString(issuer),
+	}, nil
+}

+ 267 - 0
cert/cert_v2_test.go

@@ -0,0 +1,267 @@
+package cert
+
+import (
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCertificateV2_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	nc.rawDetails = db
+
+	b, err := nc.Marshal()
+	require.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
+	assert.Nil(t, err)
+
+	assert.Equal(t, nc.Version(), Version2)
+	assert.Equal(t, nc.Curve(), Curve_CURVE25519)
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+	assert.Equal(t, nc.Issuer(), nc2.Issuer())
+
+	// unmarshalling will sort networks and unsafeNetworks, we need to do the same
+	// but first make sure it fails
+	assert.NotEqual(t, nc.Networks(), nc2.Networks())
+	assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	slices.SortFunc(nc.details.networks, comparePrefix)
+	slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV2_Expired(t *testing.T) {
+	nc := certificateV2{
+		details: detailsV2{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV2_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedf1234567890abcedf")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			isCA:      false,
+			issuer:    "1234567890abcedf1234567890abcedf",
+		},
+		publicKey: pubKey,
+		signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
+	}
+
+	b, err := nc.MarshalJSON()
+	assert.ErrorIs(t, err, ErrMissingDetails)
+
+	rd, err := nc.details.Marshal()
+	assert.NoError(t, err)
+
+	nc.rawDetails = rd
+	b, err = nc.MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
+		string(b),
+	)
+}
+
+func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	_, caKey2, err := ed25519.GenerateKey(rand.Reader)
+	require.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	ac, ok := c.(*certificateV2)
+	require.True(t, ok)
+	ac.curve = Curve(99)
+	err = c.VerifyPrivateKey(Curve(99), priv2)
+	assert.EqualError(t, err, "invalid curve: 99")
+
+	ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv)
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	aCa, ok := ca2.(*certificateV2)
+	require.True(t, ok)
+	aCa.curve = Curve(99)
+	err = aCa.VerifyPrivateKey(Curve(99), priv2)
+	assert.EqualError(t, err, "invalid curve: 99")
+
+}
+
+func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestCertificateV2_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV2(t *testing.T) {
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
+	assert.EqualError(t, err, "bad wire format")
+}
+
+func TestCertificateV2_marshalForSigningStability(t *testing.T) {
+	before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
+	after := before.Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
+	expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
+	require.NoError(t, err)
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	assert.Equal(t, expectedRawDetails, db)
+
+	expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
+	b, err := nc.marshalForSigning()
+	require.NoError(t, err)
+	assert.Equal(t, expectedForSigning, b)
+}

+ 159 - 2
cert/crypto.go

@@ -3,14 +3,28 @@ package cert
 import (
 import (
 	"crypto/aes"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/cipher"
+	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/rand"
+	"encoding/pem"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"math"
 
 
 	"golang.org/x/crypto/argon2"
 	"golang.org/x/crypto/argon2"
+	"google.golang.org/protobuf/proto"
 )
 )
 
 
-// KDF factors
+type NebulaEncryptedData struct {
+	EncryptionMetadata NebulaEncryptionMetadata
+	Ciphertext         []byte
+}
+
+type NebulaEncryptionMetadata struct {
+	EncryptionAlgorithm string
+	Argon2Parameters    Argon2Parameters
+}
+
+// Argon2Parameters KDF factors
 type Argon2Parameters struct {
 type Argon2Parameters struct {
 	version     rune
 	version     rune
 	Memory      uint32 // KiB
 	Memory      uint32 // KiB
@@ -19,7 +33,7 @@ type Argon2Parameters struct {
 	salt        []byte
 	salt        []byte
 }
 }
 
 
-// Returns a new Argon2Parameters object with current version set
+// NewArgon2Parameters Returns a new Argon2Parameters object with current version set
 func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
 func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
 	return &Argon2Parameters{
 	return &Argon2Parameters{
 		version:     argon2.Version,
 		version:     argon2.Version,
@@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) {
 
 
 	return blob[:nonceSize], blob[nonceSize:], nil
 	return blob[:nonceSize], blob[nonceSize:], nil
 }
 }
+
+// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
+func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
+	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
+	if err != nil {
+		return nil, err
+	}
+
+	b, err = proto.Marshal(&RawNebulaEncryptedData{
+		EncryptionMetadata: &RawNebulaEncryptionMetadata{
+			EncryptionAlgorithm: "AES-256-GCM",
+			Argon2Parameters: &RawNebulaArgon2Parameters{
+				Version:     kdfParams.version,
+				Memory:      kdfParams.Memory,
+				Parallelism: uint32(kdfParams.Parallelism),
+				Iterations:  kdfParams.Iterations,
+				Salt:        kdfParams.salt,
+			},
+		},
+		Ciphertext: ciphertext,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
+	default:
+		return nil, fmt.Errorf("invalid curve: %v", curve)
+	}
+}
+
+// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
+// protobuf-generated struct.
+func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rned RawNebulaEncryptedData
+	err := proto.Unmarshal(b, &rned)
+	if err != nil {
+		return nil, err
+	}
+
+	if rned.EncryptionMetadata == nil {
+		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
+	}
+
+	if rned.EncryptionMetadata.Argon2Parameters == nil {
+		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
+	}
+
+	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
+	if err != nil {
+		return nil, err
+	}
+
+	ned := NebulaEncryptedData{
+		EncryptionMetadata: NebulaEncryptionMetadata{
+			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
+			Argon2Parameters:    *params,
+		},
+		Ciphertext: rned.Ciphertext,
+	}
+
+	return &ned, nil
+}
+
+func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
+	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
+		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
+	}
+	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
+		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
+	}
+	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
+		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
+	}
+	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
+		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
+	}
+
+	return &Argon2Parameters{
+		version:     params.Version,
+		Memory:      params.Memory,
+		Parallelism: uint8(params.Parallelism),
+		Iterations:  params.Iterations,
+		salt:        params.Salt,
+	}, nil
+
+}
+
+// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
+// the given passphrase, returning any other bytes b or an error on failure
+func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
+	var curve Curve
+
+	k, r := pem.Decode(b)
+	if k == nil {
+		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+	case EncryptedECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+	default:
+		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	}
+
+	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
+	if err != nil {
+		return curve, nil, r, err
+	}
+
+	var bytes []byte
+	switch ned.EncryptionMetadata.EncryptionAlgorithm {
+	case "AES-256-GCM":
+		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
+		if err != nil {
+			return curve, nil, r, err
+		}
+	default:
+		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
+	}
+
+	switch curve {
+	case Curve_CURVE25519:
+		if len(bytes) != ed25519.PrivateKeySize {
+			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case Curve_P256:
+		if len(bytes) != 32 {
+			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	}
+
+	return curve, bytes, r, nil
+}

+ 87 - 0
cert/crypto_test.go

@@ -23,3 +23,90 @@ func TestNewArgon2Parameters(t *testing.T) {
 		Iterations:  1,
 		Iterations:  1,
 	}, p)
 	}, p)
 }
 }
+
+func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
+	passphrase := []byte("DO NOT USE THIS KEY")
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A key which, once decrypted, is too short
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
+k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
+GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
+rQr3bdH3Oy/WiYU=
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner (not encrypted)
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
+XgLvodMXZJuaFPssp+WwtA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+
+	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
+	assert.Nil(t, err)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+
+	// Fail due to invalid banner
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to invalid passphrase
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
+	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, []byte{})
+}
+
+func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
+	// Having proved that decryption works correctly above, we can test the
+	// encryption function produces a value which can be decrypted
+	passphrase := []byte("passphrase")
+	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
+	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
+	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
+	assert.Nil(t, err)
+
+	// Verify the "key" can be decrypted successfully
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
+	assert.Len(t, k, 64)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, []byte{})
+	assert.Nil(t, err)
+
+	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
+}

+ 41 - 6
cert/errors.go

@@ -2,13 +2,48 @@ package cert
 
 
 import (
 import (
 	"errors"
 	"errors"
+	"fmt"
 )
 )
 
 
 var (
 var (
-	ErrRootExpired       = errors.New("root certificate is expired")
-	ErrExpired           = errors.New("certificate is expired")
-	ErrNotCA             = errors.New("certificate is not a CA")
-	ErrNotSelfSigned     = errors.New("certificate is not self-signed")
-	ErrBlockListed       = errors.New("certificate is in the block list")
-	ErrSignatureMismatch = errors.New("certificate signature did not match")
+	ErrBadFormat                  = errors.New("bad wire format")
+	ErrRootExpired                = errors.New("root certificate is expired")
+	ErrExpired                    = errors.New("certificate is expired")
+	ErrNotCA                      = errors.New("certificate is not a CA")
+	ErrNotSelfSigned              = errors.New("certificate is not self-signed")
+	ErrBlockListed                = errors.New("certificate is in the block list")
+	ErrFingerprintMismatch        = errors.New("certificate fingerprint did not match")
+	ErrSignatureMismatch          = errors.New("certificate signature did not match")
+	ErrInvalidPublicKey           = errors.New("invalid public key")
+	ErrInvalidPrivateKey          = errors.New("invalid private key")
+	ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
+	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
+	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
+	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
+
+	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
+	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")
+	ErrInvalidPEMX25519PublicKeyBanner   = errors.New("bytes did not contain a proper X25519 public key banner")
+	ErrInvalidPEMX25519PrivateKeyBanner  = errors.New("bytes did not contain a proper X25519 private key banner")
+	ErrInvalidPEMEd25519PublicKeyBanner  = errors.New("bytes did not contain a proper Ed25519 public key banner")
+	ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
+
+	ErrNoPeerStaticKey = errors.New("no peer static key was present")
+	ErrNoPayload       = errors.New("provided payload was empty")
+
+	ErrMissingDetails  = errors.New("certificate did not contain details")
+	ErrEmptySignature  = errors.New("empty signature")
+	ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
 )
 )
+
+type ErrInvalidCertificateProperties struct {
+	str string
+}
+
+func NewErrInvalidCertificateProperties(format string, a ...any) error {
+	return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
+}
+
+func (e *ErrInvalidCertificateProperties) Error() string {
+	return e.str
+}

+ 141 - 0
cert/helper_test.go

@@ -0,0 +1,141 @@
+package cert
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 161 - 0
cert/pem.go

@@ -0,0 +1,161 @@
+package cert
+
+import (
+	"encoding/pem"
+	"fmt"
+
+	"golang.org/x/crypto/ed25519"
+)
+
+const (
+	CertificateBanner                = "NEBULA CERTIFICATE"
+	CertificateV2Banner              = "NEBULA CERTIFICATE V2"
+	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
+	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
+	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
+	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
+	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
+
+	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
+	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
+	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
+	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+)
+
+// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
+// data or an error on failure
+func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
+	p, r := pem.Decode(b)
+	if p == nil {
+		return nil, r, ErrInvalidPEMBlock
+	}
+
+	var c Certificate
+	var err error
+
+	switch p.Type {
+	// Implementations must validate the resulting certificate contains valid information
+	case CertificateBanner:
+		c, err = unmarshalCertificateV1(p.Bytes, nil)
+	case CertificateV2Banner:
+		c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
+	default:
+		return nil, r, ErrInvalidPEMCertificateBanner
+	}
+
+	if err != nil {
+		return nil, r, err
+	}
+
+	return c, r, nil
+
+}
+
+func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PublicKeyBanner:
+		// Uncompressed
+		expectedLen = 65
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
+// consumed data or an error on failure
+func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var curve Curve
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
+	case EncryptedECDSAP256PrivateKeyBanner:
+		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
+	case Ed25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+		if len(k.Bytes) != ed25519.PrivateKeySize {
+			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case ECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+		if len(k.Bytes) != 32 {
+			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner")
+	}
+	return k.Bytes, r, curve, nil
+}

+ 292 - 0
cert/pem_test.go

@@ -0,0 +1,292 @@
+package cert
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestUnmarshalCertificateFromPEM(t *testing.T) {
+	goodCert := []byte(`
+# A good cert
+-----BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NEBULA CERTIFICATE-----
+`)
+	badBanner := []byte(`# A bad banner
+-----BEGIN NOT A NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NOT A NEBULA CERTIFICATE-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-END NEBULA CERTIFICATE----`)
+
+	certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
+
+	// Success test case
+	cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
+	assert.NotNil(t, cert)
+	assert.Equal(t, rest, append(badBanner, invalidPem...))
+	assert.Nil(t, err)
+
+	// Fail due to invalid banner.
+	cert, rest, err = UnmarshalCertificateFromPEM(rest)
+	assert.Nil(t, cert)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "bytes did not contain a proper certificate banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	cert, rest, err = UnmarshalCertificateFromPEM(rest)
+	assert.Nil(t, cert)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ECDSA P256 PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-END NEBULA ED25519 PRIVATE KEY-----`)
+
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+	assert.Nil(t, err)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PRIVATE KEY-----`)
+
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+	assert.Nil(t, err)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "bytes did not contain a proper private key banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
+	pubKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA ED25519 PUBLIC KEY-----`)
+
+	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
+	assert.Equal(t, 32, len(k))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.EqualError(t, err, "bytes did not contain a proper public key banner")
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalX25519PublicKey(t *testing.T) {
+	pubKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+	pubP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PUBLIC KEY-----`)
+
+	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
+	assert.Equal(t, 32, len(k))
+	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Equal(t, 65, len(k))
+	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.EqualError(t, err, "bytes did not contain a proper public key banner")
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}

+ 167 - 0
cert/sign.go

@@ -0,0 +1,167 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"math/big"
+	"net/netip"
+	"time"
+)
+
+// TBSCertificate represents a certificate intended to be signed.
+// It is invalid to use this structure as a Certificate.
+type TBSCertificate struct {
+	Version        Version
+	Name           string
+	Networks       []netip.Prefix
+	UnsafeNetworks []netip.Prefix
+	Groups         []string
+	IsCA           bool
+	NotBefore      time.Time
+	NotAfter       time.Time
+	PublicKey      []byte
+	Curve          Curve
+	issuer         string
+}
+
+type beingSignedCertificate interface {
+	// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
+	// Implementations must validate the resulting certificate contains valid information
+	fromTBSCertificate(*TBSCertificate) error
+
+	// marshalForSigning returns the bytes that should be signed
+	marshalForSigning() ([]byte, error)
+
+	// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
+	setSignature([]byte) error
+}
+
+type SignerLambda func(certBytes []byte) ([]byte, error)
+
+// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
+// details do not violate constraints of the signing certificate.
+// If the TBSCertificate is a CA then signer must be nil.
+func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
+	switch t.Curve {
+	case Curve_CURVE25519:
+		pk := ed25519.PrivateKey(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			sig := ed25519.Sign(pk, certBytes)
+			return sig, nil
+		}
+		return t.SignWith(signer, curve, sp)
+	case Curve_P256:
+		pk := &ecdsa.PrivateKey{
+			PublicKey: ecdsa.PublicKey{
+				Curve: elliptic.P256(),
+			},
+			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+			D: new(big.Int).SetBytes(key),
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			// We need to hash first for ECDSA
+			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
+			hashed := sha256.Sum256(certBytes)
+			return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
+		}
+		return t.SignWith(signer, curve, sp)
+	default:
+		return nil, fmt.Errorf("invalid curve: %s", t.Curve)
+	}
+}
+
+// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
+// You should only use SignWith if you do not have direct access to your private key.
+func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
+	if curve != t.Curve {
+		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+
+	if signer != nil {
+		if t.IsCA {
+			return nil, fmt.Errorf("can not sign a CA certificate with another")
+		}
+
+		err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks)
+		if err != nil {
+			return nil, err
+		}
+
+		issuer, err := signer.Fingerprint()
+		if err != nil {
+			return nil, fmt.Errorf("error computing issuer: %v", err)
+		}
+		t.issuer = issuer
+	} else {
+		if !t.IsCA {
+			return nil, fmt.Errorf("self signed certificates must have IsCA set to true")
+		}
+	}
+
+	var c beingSignedCertificate
+	switch t.Version {
+	case Version1:
+		c = &certificateV1{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	case Version2:
+		c = &certificateV2{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	default:
+		return nil, fmt.Errorf("unknown cert version %d", t.Version)
+	}
+
+	certBytes, err := c.marshalForSigning()
+	if err != nil {
+		return nil, err
+	}
+
+	sig, err := sp(certBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	err = c.setSignature(sig)
+	if err != nil {
+		return nil, err
+	}
+
+	sc, ok := c.(Certificate)
+	if !ok {
+		return nil, fmt.Errorf("invalid certificate")
+	}
+
+	return sc, nil
+}
+
+func comparePrefix(a, b netip.Prefix) int {
+	addr := a.Addr().Compare(b.Addr())
+	if addr == 0 {
+		return a.Bits() - b.Bits()
+	}
+	return addr
+}
+
+// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
+func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
+	if len(sortedPrefixes) < 2 {
+		return nil
+	}
+	for i := 1; i < len(sortedPrefixes); i++ {
+		if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
+			return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
+		}
+	}
+	return nil
+}

+ 90 - 0
cert/sign_test.go

@@ -0,0 +1,90 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCertificateV1_Sign(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/24"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+	}
+
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
+	assert.Nil(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	assert.Nil(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, uc)
+}
+
+func TestCertificateV1_SignP256(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/16"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+		Curve:     Curve_P256,
+	}
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	assert.NoError(t, err)
+	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
+	rawPriv := priv.D.FillBytes(make([]byte, 32))
+
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
+	assert.Nil(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	assert.Nil(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, uc)
+}

+ 138 - 0
cert_test/cert.go

@@ -0,0 +1,138 @@
+package cert_test
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case cert.Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &cert.TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case cert.Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &cert.TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 137 - 73
cmd/nebula-cert/ca.go

@@ -8,13 +8,14 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"math"
 	"math"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
 	"github.com/skip2/go-qrcode"
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/ed25519"
 	"golang.org/x/crypto/ed25519"
 )
 )
 
 
@@ -26,32 +27,43 @@ type caFlags struct {
 	outCertPath      *string
 	outCertPath      *string
 	outQRPath        *string
 	outQRPath        *string
 	groups           *string
 	groups           *string
-	ips              *string
-	subnets          *string
+	networks         *string
+	unsafeNetworks   *string
 	argonMemory      *uint
 	argonMemory      *uint
 	argonIterations  *uint
 	argonIterations  *uint
 	argonParallelism *uint
 	argonParallelism *uint
 	encryption       *bool
 	encryption       *bool
+	version          *uint
 
 
-	curve *string
+	curve  *string
+	p11url *string
+
+	// Deprecated options
+	ips     *string
+	subnets *string
 }
 }
 
 
 func newCaFlags() *caFlags {
 func newCaFlags() *caFlags {
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf.set.Usage = func() {}
 	cf.set.Usage = func() {}
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
+	cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
-	cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
-	cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
+	cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
+	cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
+
+	cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
+	cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &cf
 	return &cf
 }
 }
 
 
@@ -76,17 +88,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return err
 		return err
 	}
 	}
 
 
+	isP11 := len(*cf.p11url) > 0
+
 	if err := mustFlagString("name", cf.name); err != nil {
 	if err := mustFlagString("name", cf.name); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
 	}
 	if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
 	if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
 		return err
 		return err
 	}
 	}
 	var kdfParams *cert.Argon2Parameters
 	var kdfParams *cert.Argon2Parameters
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil {
 		if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil {
 			return err
 			return err
 		}
 		}
@@ -106,44 +122,57 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		}
 		}
 	}
 	}
 
 
-	var ips []*net.IPNet
-	if *cf.ips != "" {
-		for _, rs := range strings.Split(*cf.ips, ",") {
+	version := cert.Version(*cf.version)
+	if version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
+	var networks []netip.Prefix
+	if *cf.networks == "" && *cf.ips != "" {
+		// Pull up deprecated -ips flag if needed
+		*cf.networks = *cf.ips
+	}
+
+	if *cf.networks != "" {
+		for _, rs := range strings.Split(*cf.networks, ",") {
 			rs := strings.Trim(rs, " ")
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 			if rs != "" {
-				ip, ipNet, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
 				if err != nil {
-					return newHelpErrorf("invalid ip definition: %s", err)
+					return newHelpErrorf("invalid -networks definition: %s", rs)
 				}
 				}
-				if ip.To4() == nil {
-					return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
 				}
-
-				ipNet.IP = ip
-				ips = append(ips, ipNet)
+				networks = append(networks, n)
 			}
 			}
 		}
 		}
 	}
 	}
 
 
-	var subnets []*net.IPNet
-	if *cf.subnets != "" {
-		for _, rs := range strings.Split(*cf.subnets, ",") {
+	var unsafeNetworks []netip.Prefix
+	if *cf.unsafeNetworks == "" && *cf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*cf.unsafeNetworks = *cf.subnets
+	}
+
+	if *cf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 			if rs != "" {
-				_, s, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
 				}
-				if s.IP.To4() == nil {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
 				}
-				subnets = append(subnets, s)
+				unsafeNetworks = append(unsafeNetworks, n)
 			}
 			}
 		}
 		}
 	}
 	}
 
 
 	var passphrase []byte
 	var passphrase []byte
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		for i := 0; i < 5; i++ {
 		for i := 0; i < 5; i++ {
 			out.Write([]byte("Enter passphrase: "))
 			out.Write([]byte("Enter passphrase: "))
 			passphrase, err = pr.ReadPassword()
 			passphrase, err = pr.ReadPassword()
@@ -166,74 +195,109 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 
 
 	var curve cert.Curve
 	var curve cert.Curve
 	var pub, rawPriv []byte
 	var pub, rawPriv []byte
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		curve = cert.Curve_CURVE25519
-		pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+
+		p11Client, err = pkclient.FromUrl(*cf.p11url)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
 		}
 		}
-	case "P256":
-		var key *ecdsa.PrivateKey
-		curve = cert.Curve_P256
-		key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
 		}
 		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			curve = cert.Curve_CURVE25519
+			pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			}
+		case "P256":
+			var key *ecdsa.PrivateKey
+			curve = cert.Curve_P256
+			key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			}
 
 
-		// ecdh.PrivateKey lets us get at the encoded bytes, even though
-		// we aren't using ECDH here.
-		eKey, err := key.ECDH()
-		if err != nil {
-			return fmt.Errorf("error while converting ecdsa key: %s", err)
+			// ecdh.PrivateKey lets us get at the encoded bytes, even though
+			// we aren't using ECDH here.
+			eKey, err := key.ECDH()
+			if err != nil {
+				return fmt.Errorf("error while converting ecdsa key: %s", err)
+			}
+			rawPriv = eKey.Bytes()
+			pub = eKey.PublicKey().Bytes()
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
 		}
 		}
-		rawPriv = eKey.Bytes()
-		pub = eKey.PublicKey().Bytes()
 	}
 	}
 
 
-	nc := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      *cf.name,
-			Groups:    groups,
-			Ips:       ips,
-			Subnets:   subnets,
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(*cf.duration),
-			PublicKey: pub,
-			IsCA:      true,
-			Curve:     curve,
-		},
+	t := &cert.TBSCertificate{
+		Version:        version,
+		Name:           *cf.name,
+		Groups:         groups,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		NotBefore:      time.Now(),
+		NotAfter:       time.Now().Add(*cf.duration),
+		PublicKey:      pub,
+		IsCA:           true,
+		Curve:          curve,
 	}
 	}
 
 
-	if _, err := os.Stat(*cf.outKeyPath); err == nil {
-		return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+	if !isP11 {
+		if _, err := os.Stat(*cf.outKeyPath); err == nil {
+			return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+		}
 	}
 	}
 
 
 	if _, err := os.Stat(*cf.outCertPath); err == nil {
 	if _, err := os.Stat(*cf.outCertPath); err == nil {
 		return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
 		return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
 	}
 	}
 
 
-	err = nc.Sign(curve, rawPriv)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
-	}
-
+	var c cert.Certificate
 	var b []byte
 	var b []byte
-	if *cf.encryption {
-		b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+
+	if isP11 {
+		c, err = t.SignWith(nil, curve, p11Client.SignASN1)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("error while encrypting out-key: %s", err)
+			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 		}
 		}
 	} else {
 	} else {
-		b = cert.MarshalSigningPrivateKey(curve, rawPriv)
-	}
+		c, err = t.Sign(nil, curve, rawPriv)
+		if err != nil {
+			return fmt.Errorf("error while signing: %s", err)
+		}
 
 
-	err = os.WriteFile(*cf.outKeyPath, b, 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+		if *cf.encryption {
+			b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+			if err != nil {
+				return fmt.Errorf("error while encrypting out-key: %s", err)
+			}
+		} else {
+			b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
+		}
+
+		err = os.WriteFile(*cf.outKeyPath, b, 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
 	}
 
 
-	b, err = nc.MarshalToPEM()
+	b, err = c.MarshalPEM()
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}
 	}

+ 36 - 30
cmd/nebula-cert/ca_test.go

@@ -16,8 +16,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-//TODO: test file permissions
-
 func Test_caSummary(t *testing.T) {
 func Test_caSummary(t *testing.T) {
 	assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
 	assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
 }
 }
@@ -43,17 +41,24 @@ func Test_caHelp(t *testing.T) {
 			"  -groups string\n"+
 			"  -groups string\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"  -ips string\n"+
 			"  -ips string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
+			"    	Deprecated, see -networks\n"+
 			"  -name string\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the certificate authority\n"+
 			"    \tRequired: name of the certificate authority\n"+
+			"  -networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
 			"  -out-crt string\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"  -out-key string\n"+
 			"  -out-key string\n"+
 			"    \tOptional: path to write the private key to (default \"ca.key\")\n"+
 			"    \tOptional: path to write the private key to (default \"ca.key\")\n"+
 			"  -out-qr string\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use (default 2)\n",
 		ob.String(),
 		ob.String(),
 	)
 	)
 }
 }
@@ -82,25 +87,25 @@ func Test_ca(t *testing.T) {
 
 
 	// required args
 	// required args
 	assertHelpError(t, ca(
 	assertHelpError(t, ca(
-		[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
+		[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
 	), "-name is required")
 	), "-name is required")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// ipv4 only ips
 	// ipv4 only ips
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// ipv4 only subnets
 	// ipv4 only subnets
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// failed key write
 	// failed key write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
+	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -108,12 +113,12 @@ func Test_ca(t *testing.T) {
 	// create temp key file
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
 	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	os.Remove(keyF.Name())
+	assert.Nil(t, os.Remove(keyF.Name()))
 
 
 	// failed cert write
 	// failed cert write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -121,45 +126,46 @@ func Test_ca(t *testing.T) {
 	// create temp cert file
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
 	crtF, err := os.CreateTemp("", "test.crt")
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	os.Remove(crtF.Name())
-	os.Remove(keyF.Name())
+	assert.Nil(t, os.Remove(crtF.Name()))
+	assert.Nil(t, os.Remove(keyF.Name()))
 
 
 	// test proper cert with removed empty groups and subnets
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// read cert and key files
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
+	lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, c)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 64)
 	assert.Len(t, lKey, 64)
 
 
 	rb, _ = os.ReadFile(crtF.Name())
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
+	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
-	assert.Equal(t, "test", lCrt.Details.Name)
-	assert.Len(t, lCrt.Details.Ips, 0)
-	assert.True(t, lCrt.Details.IsCA)
-	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
-	assert.Len(t, lCrt.Details.Subnets, 0)
-	assert.Len(t, lCrt.Details.PublicKey, 32)
-	assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
-	assert.Equal(t, "", lCrt.Details.Issuer)
-	assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
+	assert.Equal(t, "test", lCrt.Name())
+	assert.Len(t, lCrt.Networks(), 0)
+	assert.True(t, lCrt.IsCA())
+	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
+	assert.Len(t, lCrt.UnsafeNetworks(), 0)
+	assert.Len(t, lCrt.PublicKey(), 32)
+	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
+	assert.Equal(t, "", lCrt.Issuer())
+	assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
 
 
 	// test encrypted key
 	// test encrypted key
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, testpw))
 	assert.Nil(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -187,7 +193,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Error(t, ca(args, ob, eb, errpw))
 	assert.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -197,7 +203,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -207,13 +213,13 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Nil(t, ca(args, ob, eb, nopw))
 
 
 	// test that we won't overwrite existing certificate file
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -222,7 +228,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())

+ 49 - 20
cmd/nebula-cert/keygen.go

@@ -6,6 +6,8 @@ import (
 	"io"
 	"io"
 	"os"
 	"os"
 
 
+	"github.com/slackhq/nebula/pkclient"
+
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 )
 )
 
 
@@ -13,8 +15,8 @@ type keygenFlags struct {
 	set        *flag.FlagSet
 	set        *flag.FlagSet
 	outKeyPath *string
 	outKeyPath *string
 	outPubPath *string
 	outPubPath *string
-
-	curve *string
+	curve      *string
+	p11url     *string
 }
 }
 
 
 func newKeygenFlags() *keygenFlags {
 func newKeygenFlags() *keygenFlags {
@@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags {
 	cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
 	cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
 	cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
 	cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
 	cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
 	cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
 	return &cf
 	return &cf
 }
 }
 
 
@@ -33,32 +36,58 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 		return err
 	}
 	}
 
 
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	isP11 := len(*cf.p11url) > 0
+
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
 	}
-	if err := mustFlagString("out-pub", cf.outPubPath); err != nil {
+	if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
 		return err
 		return err
 	}
 	}
 
 
 	var pub, rawPriv []byte
 	var pub, rawPriv []byte
 	var curve cert.Curve
 	var curve cert.Curve
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		pub, rawPriv = x25519Keypair()
-		curve = cert.Curve_CURVE25519
-	case "P256":
-		pub, rawPriv = p256Keypair()
-		curve = cert.Curve_P256
-	default:
-		return fmt.Errorf("invalid curve: %s", *cf.curve)
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			pub, rawPriv = x25519Keypair()
+			curve = cert.Curve_CURVE25519
+		case "P256":
+			pub, rawPriv = p256Keypair()
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
+		}
 	}
 	}
 
 
-	err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+	if isP11 {
+		p11Client, err := pkclient.FromUrl(*cf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key: %w", err)
+		}
+	} else {
+		err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
 	}
-
-	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
+	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
 		return fmt.Errorf("error while writing out-pub: %s", err)
 	}
 	}
@@ -72,7 +101,7 @@ func keygenSummary() string {
 
 
 func keygenHelp(out io.Writer) {
 func keygenHelp(out io.Writer) {
 	cf := newKeygenFlags()
 	cf := newKeygenFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
 	cf.set.SetOutput(out)
 	cf.set.SetOutput(out)
 	cf.set.PrintDefaults()
 	cf.set.PrintDefaults()
 }
 }

+ 6 - 5
cmd/nebula-cert/keygen_test.go

@@ -9,8 +9,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-//TODO: test file permissions
-
 func Test_keygenSummary(t *testing.T) {
 func Test_keygenSummary(t *testing.T) {
 	assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
 	assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
 }
 }
@@ -26,7 +24,8 @@ func Test_keygenHelp(t *testing.T) {
 			"  -out-key string\n"+
 			"  -out-key string\n"+
 			"    \tRequired: path to write the private key to\n"+
 			"    \tRequired: path to write the private key to\n"+
 			"  -out-pub string\n"+
 			"  -out-pub string\n"+
-			"    \tRequired: path to write the public key to\n",
+			"    \tRequired: path to write the public key to\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n"),
 		ob.String(),
 		ob.String(),
 	)
 	)
 }
 }
@@ -80,13 +79,15 @@ func Test_keygen(t *testing.T) {
 
 
 	// read cert and key files
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
+	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 	assert.Len(t, lKey, 32)
 
 
 	rb, _ = os.ReadFile(pubF.Name())
 	rb, _ = os.ReadFile(pubF.Name())
-	lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
+	lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, lPub, 32)
 	assert.Len(t, lPub, 32)

+ 10 - 3
cmd/nebula-cert/main_test.go

@@ -3,6 +3,7 @@ package main
 import (
 import (
 	"bytes"
 	"bytes"
 	"errors"
 	"errors"
+	"fmt"
 	"io"
 	"io"
 	"os"
 	"os"
 	"testing"
 	"testing"
@@ -10,8 +11,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-//TODO: all flag parsing continueOnError will print to stderr on its own currently
-
 func Test_help(t *testing.T) {
 func Test_help(t *testing.T) {
 	expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
 	expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
 		"  Global flags:\n" +
 		"  Global flags:\n" +
@@ -77,8 +76,16 @@ func assertHelpError(t *testing.T, err error, msg string) {
 	case *helpError:
 	case *helpError:
 		// good
 		// good
 	default:
 	default:
-		t.Fatal("err was not a helpError")
+		t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
 	}
 	}
 
 
 	assert.EqualError(t, err, msg)
 	assert.EqualError(t, err, msg)
 }
 }
+
+func optionalPkcs11String(msg string) string {
+	if p11Supported() {
+		return msg
+	} else {
+		return ""
+	}
+}

+ 15 - 0
cmd/nebula-cert/p11_cgo.go

@@ -0,0 +1,15 @@
+//go:build cgo && pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return true
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key")
+}

+ 16 - 0
cmd/nebula-cert/p11_stub.go

@@ -0,0 +1,16 @@
+//go:build !cgo || !pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return false
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	var ret = ""
+	return &ret
+}

+ 14 - 9
cmd/nebula-cert/print.go

@@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("unable to read cert; %s", err)
 		return fmt.Errorf("unable to read cert; %s", err)
 	}
 	}
 
 
-	var c *cert.NebulaCertificate
+	var c cert.Certificate
 	var qrBytes []byte
 	var qrBytes []byte
 	part := 0
 	part := 0
 
 
+	var jsonCerts []cert.Certificate
+
 	for {
 	for {
-		c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
+		c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("error while unmarshaling cert: %s", err)
 			return fmt.Errorf("error while unmarshaling cert: %s", err)
 		}
 		}
 
 
 		if *pf.json {
 		if *pf.json {
-			b, _ := json.Marshal(c)
-			out.Write(b)
-			out.Write([]byte("\n"))
-
+			jsonCerts = append(jsonCerts, c)
 		} else {
 		} else {
-			out.Write([]byte(c.String()))
-			out.Write([]byte("\n"))
+			_, _ = out.Write([]byte(c.String()))
+			_, _ = out.Write([]byte("\n"))
 		}
 		}
 
 
 		if *pf.outQRPath != "" {
 		if *pf.outQRPath != "" {
-			b, err := c.MarshalToPEM()
+			b, err := c.MarshalPEM()
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("error while marshalling cert to PEM: %s", err)
 				return fmt.Errorf("error while marshalling cert to PEM: %s", err)
 			}
 			}
@@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		part++
 		part++
 	}
 	}
 
 
+	if *pf.json {
+		b, _ := json.Marshal(jsonCerts)
+		_, _ = out.Write(b)
+		_, _ = out.Write([]byte("\n"))
+	}
+
 	if *pf.outQRPath != "" {
 	if *pf.outQRPath != "" {
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		if err != nil {
 		if err != nil {

+ 144 - 21
cmd/nebula-cert/print_test.go

@@ -2,6 +2,10 @@ package main
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
 	"os"
 	"os"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -68,25 +72,86 @@ func Test_printCert(t *testing.T) {
 	eb.Reset()
 	eb.Reset()
 	tf.Truncate(0)
 	tf.Truncate(0)
 	tf.Seek(0, 0)
 	tf.Seek(0, 0)
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test",
-			Groups:    []string{"hi"},
-			PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-		},
-		Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-	}
+	ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
+	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
 
 
-	p, _ := c.MarshalToPEM()
+	p, _ := c.MarshalPEM()
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 
 
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
+	fp, _ := c.Fingerprint()
+	pk := hex.EncodeToString(c.PublicKey())
+	sig := hex.EncodeToString(c.Signature())
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(
 	assert.Equal(
 		t,
 		t,
-		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
+		//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		`{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+`,
 		ob.String(),
 		ob.String(),
 	)
 	)
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -96,26 +161,84 @@ func Test_printCert(t *testing.T) {
 	eb.Reset()
 	eb.Reset()
 	tf.Truncate(0)
 	tf.Truncate(0)
 	tf.Seek(0, 0)
 	tf.Seek(0, 0)
-	c = cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test",
-			Groups:    []string{"hi"},
-			PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-		},
-		Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-	}
-
-	p, _ = c.MarshalToPEM()
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 
 
 	err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
 	err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
+	fp, _ = c.Fingerprint()
+	pk = hex.EncodeToString(c.PublicKey())
+	sig = hex.EncodeToString(c.Signature())
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(
 	assert.Equal(
 		t,
 		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
+		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
+`,
 		ob.String(),
 		ob.String(),
 	)
 	)
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 }
 }
+
+// NewTestCaCert will generate a CA cert
+func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
+	var err error
+	if pubKey == nil || privKey == nil {
+		pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	t := &cert.TBSCertificate{
+		Version:        cert.Version1,
+		Name:           name,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pubKey,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey)
+	if err != nil {
+		panic(err)
+	}
+
+	return c, privKey
+}
+
+func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
+	if before.IsZero() {
+		before = ca.NotBefore()
+	}
+
+	if after.IsZero() {
+		after = ca.NotAfter()
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	pub, rawPriv := x25519Keypair()
+	nc := &cert.TBSCertificate{
+		Version:        cert.Version1,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), signerKey)
+	if err != nil {
+		panic(err)
+	}
+
+	return c, rawPriv
+}

+ 238 - 110
cmd/nebula-cert/sign.go

@@ -3,50 +3,63 @@ package main
 import (
 import (
 	"crypto/ecdh"
 	"crypto/ecdh"
 	"crypto/rand"
 	"crypto/rand"
+	"errors"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
 	"github.com/skip2/go-qrcode"
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/curve25519"
 )
 )
 
 
 type signFlags struct {
 type signFlags struct {
-	set         *flag.FlagSet
-	caKeyPath   *string
-	caCertPath  *string
-	name        *string
-	ip          *string
-	duration    *time.Duration
-	inPubPath   *string
-	outKeyPath  *string
-	outCertPath *string
-	outQRPath   *string
-	groups      *string
-	subnets     *string
+	set            *flag.FlagSet
+	version        *uint
+	caKeyPath      *string
+	caCertPath     *string
+	name           *string
+	networks       *string
+	unsafeNetworks *string
+	duration       *time.Duration
+	inPubPath      *string
+	outKeyPath     *string
+	outCertPath    *string
+	outQRPath      *string
+	groups         *string
+
+	p11url *string
+
+	// Deprecated options
+	ip      *string
+	subnets *string
 }
 }
 
 
 func newSignFlags() *signFlags {
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
 	sf.set.Usage = func() {}
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
-	sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
+	sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
+	sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
-	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
-	return &sf
+	sf.p11url = p11Flag(sf.set)
 
 
+	sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
+	sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
+	return &sf
 }
 }
 
 
 func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
 func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
@@ -56,8 +69,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return err
 		return err
 	}
 	}
 
 
-	if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
-		return err
+	isP11 := len(*sf.p11url) > 0
+
+	if !isP11 {
+		if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
+			return err
+		}
 	}
 	}
 	if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
 	if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
 		return err
 		return err
@@ -65,50 +82,67 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	if err := mustFlagString("name", sf.name); err != nil {
 	if err := mustFlagString("name", sf.name); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := mustFlagString("ip", sf.ip); err != nil {
-		return err
-	}
-	if *sf.inPubPath != "" && *sf.outKeyPath != "" {
+	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 	}
 
 
-	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
-	if err != nil {
-		return fmt.Errorf("error while reading ca-key: %s", err)
+	var v4Networks []netip.Prefix
+	var v6Networks []netip.Prefix
+	if *sf.networks == "" && *sf.ip != "" {
+		// Pull up deprecated -ip flag if needed
+		*sf.networks = *sf.ip
+	}
+
+	if len(*sf.networks) == 0 {
+		return newHelpErrorf("-networks is required")
+	}
+
+	version := cert.Version(*sf.version)
+	if version != 0 && version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
 	}
 	}
 
 
 	var curve cert.Curve
 	var curve cert.Curve
 	var caKey []byte
 	var caKey []byte
 
 
-	// naively attempt to decode the private key as though it is not encrypted
-	caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
-	if err == cert.ErrPrivateKeyEncrypted {
-		// ask for a passphrase until we get one
-		var passphrase []byte
-		for i := 0; i < 5; i++ {
-			out.Write([]byte("Enter passphrase: "))
-			passphrase, err = pr.ReadPassword()
-
-			if err == ErrNoTerminal {
-				return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
-			} else if err != nil {
-				return fmt.Errorf("error reading password: %s", err)
-			}
+	if !isP11 {
+		var rawCAKey []byte
+		rawCAKey, err := os.ReadFile(*sf.caKeyPath)
 
 
-			if len(passphrase) > 0 {
-				break
-			}
-		}
-		if len(passphrase) == 0 {
-			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+		if err != nil {
+			return fmt.Errorf("error while reading ca-key: %s", err)
 		}
 		}
 
 
-		curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
-		if err != nil {
-			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+		// naively attempt to decode the private key as though it is not encrypted
+		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
+		if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
+			// ask for a passphrase until we get one
+			var passphrase []byte
+			for i := 0; i < 5; i++ {
+				out.Write([]byte("Enter passphrase: "))
+				passphrase, err = pr.ReadPassword()
+
+				if errors.Is(err, ErrNoTerminal) {
+					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+				} else if err != nil {
+					return fmt.Errorf("error reading password: %s", err)
+				}
+
+				if len(passphrase) > 0 {
+					break
+				}
+			}
+			if len(passphrase) == 0 {
+				return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+			}
+
+			curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
+			if err != nil {
+				return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+			}
+		} else if err != nil {
+			return fmt.Errorf("error while parsing ca-key: %s", err)
 		}
 		}
-	} else if err != nil {
-		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 	}
 
 
 	rawCACert, err := os.ReadFile(*sf.caCertPath)
 	rawCACert, err := os.ReadFile(*sf.caCertPath)
@@ -116,18 +150,15 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while reading ca-crt: %s", err)
 		return fmt.Errorf("error while reading ca-crt: %s", err)
 	}
 	}
 
 
-	caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
+	caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("error while parsing ca-crt: %s", err)
 		return fmt.Errorf("error while parsing ca-crt: %s", err)
 	}
 	}
 
 
-	if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate does not match private key")
-	}
-
-	issuer, err := caCert.Sha256Sum()
-	if err != nil {
-		return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
+	if !isP11 {
+		if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
+			return fmt.Errorf("refusing to sign, root certificate does not match private key")
+		}
 	}
 	}
 
 
 	if caCert.Expired(time.Now()) {
 	if caCert.Expired(time.Now()) {
@@ -136,82 +167,99 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 
 	// if no duration is given, expire one second before the root expires
 	// if no duration is given, expire one second before the root expires
 	if *sf.duration <= 0 {
 	if *sf.duration <= 0 {
-		*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
+		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
 	}
 	}
 
 
-	ip, ipNet, err := net.ParseCIDR(*sf.ip)
-	if err != nil {
-		return newHelpErrorf("invalid ip definition: %s", err)
-	}
-	if ip.To4() == nil {
-		return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
-	}
-	ipNet.IP = ip
+	if *sf.networks != "" {
+		for _, rs := range strings.Split(*sf.networks, ",") {
+			rs := strings.Trim(rs, " ")
+			if rs != "" {
+				n, err := netip.ParsePrefix(rs)
+				if err != nil {
+					return newHelpErrorf("invalid -networks definition: %s", rs)
+				}
 
 
-	groups := []string{}
-	if *sf.groups != "" {
-		for _, rg := range strings.Split(*sf.groups, ",") {
-			g := strings.TrimSpace(rg)
-			if g != "" {
-				groups = append(groups, g)
+				if n.Addr().Is4() {
+					v4Networks = append(v4Networks, n)
+				} else {
+					v6Networks = append(v6Networks, n)
+				}
 			}
 			}
 		}
 		}
 	}
 	}
 
 
-	subnets := []*net.IPNet{}
-	if *sf.subnets != "" {
-		for _, rs := range strings.Split(*sf.subnets, ",") {
+	var v4UnsafeNetworks []netip.Prefix
+	var v6UnsafeNetworks []netip.Prefix
+	if *sf.unsafeNetworks == "" && *sf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*sf.unsafeNetworks = *sf.subnets
+	}
+
+	if *sf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 			if rs != "" {
-				_, s, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
 				}
-				if s.IP.To4() == nil {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+
+				if n.Addr().Is4() {
+					v4UnsafeNetworks = append(v4UnsafeNetworks, n)
+				} else {
+					v6UnsafeNetworks = append(v6UnsafeNetworks, n)
 				}
 				}
-				subnets = append(subnets, s)
+			}
+		}
+	}
+
+	var groups []string
+	if *sf.groups != "" {
+		for _, rg := range strings.Split(*sf.groups, ",") {
+			g := strings.TrimSpace(rg)
+			if g != "" {
+				groups = append(groups, g)
 			}
 			}
 		}
 		}
 	}
 	}
 
 
 	var pub, rawPriv []byte
 	var pub, rawPriv []byte
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		curve = cert.Curve_P256
+		p11Client, err = pkclient.FromUrl(*sf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+	}
+
 	if *sf.inPubPath != "" {
 	if *sf.inPubPath != "" {
+		var pubCurve cert.Curve
 		rawPub, err := os.ReadFile(*sf.inPubPath)
 		rawPub, err := os.ReadFile(*sf.inPubPath)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
 		}
-		var pubCurve cert.Curve
-		pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
+
+		pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("error while parsing in-pub: %s", err)
 			return fmt.Errorf("error while parsing in-pub: %s", err)
 		}
 		}
 		if pubCurve != curve {
 		if pubCurve != curve {
 			return fmt.Errorf("curve of in-pub does not match ca")
 			return fmt.Errorf("curve of in-pub does not match ca")
 		}
 		}
+	} else if isP11 {
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
+		}
 	} else {
 	} else {
 		pub, rawPriv = newKeypair(curve)
 		pub, rawPriv = newKeypair(curve)
 	}
 	}
 
 
-	nc := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      *sf.name,
-			Ips:       []*net.IPNet{ipNet},
-			Groups:    groups,
-			Subnets:   subnets,
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(*sf.duration),
-			PublicKey: pub,
-			IsCA:      false,
-			Issuer:    issuer,
-			Curve:     curve,
-		},
-	}
-
-	if err := nc.CheckRootConstrains(caCert); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err)
-	}
-
 	if *sf.outKeyPath == "" {
 	if *sf.outKeyPath == "" {
 		*sf.outKeyPath = *sf.name + ".key"
 		*sf.outKeyPath = *sf.name + ".key"
 	}
 	}
@@ -224,25 +272,105 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 	}
 
 
-	err = nc.Sign(curve, caKey)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
+	var crts []cert.Certificate
+
+	notBefore := time.Now()
+	notAfter := notBefore.Add(*sf.duration)
+
+	if version == 0 || version == cert.Version1 {
+		// Make sure we at least have an ip
+		if len(v4Networks) != 1 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
+		}
+
+		if version == cert.Version1 {
+			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
+			if len(v6Networks) > 0 {
+				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
+			}
+
+			if len(v6UnsafeNetworks) > 0 {
+				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+			}
+		}
+
+		t := &cert.TBSCertificate{
+			Version:        cert.Version1,
+			Name:           *sf.name,
+			Networks:       []netip.Prefix{v4Networks[0]},
+			Groups:         groups,
+			UnsafeNetworks: v4UnsafeNetworks,
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
+	}
+
+	if version == 0 || version == cert.Version2 {
+		t := &cert.TBSCertificate{
+			Version:        cert.Version2,
+			Name:           *sf.name,
+			Networks:       append(v4Networks, v6Networks...),
+			Groups:         groups,
+			UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
 	}
 	}
 
 
-	if *sf.inPubPath == "" {
+	if !isP11 && *sf.inPubPath == "" {
 		if _, err := os.Stat(*sf.outKeyPath); err == nil {
 		if _, err := os.Stat(*sf.outKeyPath); err == nil {
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}
 		}
 
 
-		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("error while writing out-key: %s", err)
 			return fmt.Errorf("error while writing out-key: %s", err)
 		}
 		}
 	}
 	}
 
 
-	b, err := nc.MarshalToPEM()
-	if err != nil {
-		return fmt.Errorf("error while marshalling certificate: %s", err)
+	var b []byte
+	for _, c := range crts {
+		sb, err := c.MarshalPEM()
+		if err != nil {
+			return fmt.Errorf("error while marshalling certificate: %s", err)
+		}
+		b = append(b, sb...)
 	}
 	}
 
 
 	err = os.WriteFile(*sf.outCertPath, b, 0600)
 	err = os.WriteFile(*sf.outCertPath, b, 0600)

+ 74 - 75
cmd/nebula-cert/sign_test.go

@@ -16,8 +16,6 @@ import (
 	"golang.org/x/crypto/ed25519"
 	"golang.org/x/crypto/ed25519"
 )
 )
 
 
-//TODO: test file permissions
-
 func Test_signSummary(t *testing.T) {
 func Test_signSummary(t *testing.T) {
 	assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
 	assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
 }
 }
@@ -39,17 +37,24 @@ func Test_signHelp(t *testing.T) {
 			"  -in-pub string\n"+
 			"  -in-pub string\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"  -ip string\n"+
 			"  -ip string\n"+
-			"    \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
+			"    \tDeprecated, see -networks\n"+
 			"  -name string\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
+			"  -networks string\n"+
+			"    \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
 			"  -out-crt string\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"  -out-key string\n"+
 			"  -out-key string\n"+
 			"    \tOptional (if in-pub not set): path to write the private key to\n"+
 			"    \tOptional (if in-pub not set): path to write the private key to\n"+
 			"  -out-qr string\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
 		ob.String(),
 		ob.String(),
 	)
 	)
 }
 }
@@ -76,20 +81,20 @@ func Test_signCert(t *testing.T) {
 
 
 	// required args
 	// required args
 	assertHelpError(t, signCert(
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
 	), "-name is required")
 	), "-name is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	assertHelpError(t, signCert(
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
-	), "-ip is required")
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-networks is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// cannot set -in-pub and -out-key
 	// cannot set -in-pub and -out-key
 	assertHelpError(t, signCert(
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
 	), "cannot set both -in-pub and -out-key")
 	), "cannot set both -in-pub and -out-key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -97,7 +102,7 @@ func Test_signCert(t *testing.T) {
 	// failed to read key
 	// failed to read key
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 
 	// failed to unmarshal key
 	// failed to unmarshal key
@@ -107,7 +112,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 	defer os.Remove(caKeyF.Name())
 
 
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -116,10 +121,10 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
-	caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv))
+	caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
 
 
 	// failed to read cert
 	// failed to read cert
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -131,26 +136,18 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 	defer os.Remove(caCrtF.Name())
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// write a proper ca cert for later
 	// write a proper ca cert for later
-	ca := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(time.Minute * 200),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	b, _ := ca.MarshalToPEM()
+	ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
+	b, _ := ca.MarshalPEM()
 	caCrtF.Write(b)
 	caCrtF.Write(b)
 
 
 	// failed to read pub
 	// failed to read pub
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -162,7 +159,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(inPubF.Name())
 	defer os.Remove(inPubF.Name())
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -171,35 +168,42 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
 	inPub, _ := x25519Keypair()
 	inPub, _ := x25519Keypair()
-	inPubF.Write(cert.MarshalX25519PublicKey(inPub))
+	inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub))
 
 
 	// bad ip cidr
 	// bad ip cidr
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// bad subnet cidr
 	// bad subnet cidr
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
@@ -208,11 +212,11 @@ func Test_signCert(t *testing.T) {
 	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
 	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF2.Name())
 	defer os.Remove(caKeyF2.Name())
-	caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
+	caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
 
 
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -220,7 +224,7 @@ func Test_signCert(t *testing.T) {
 	// failed key write
 	// failed key write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -233,7 +237,7 @@ func Test_signCert(t *testing.T) {
 	// failed cert write
 	// failed cert write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -247,40 +251,41 @@ func Test_signCert(t *testing.T) {
 	// test proper cert with removed empty groups and subnets
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// read cert and key files
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
+	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 	assert.Len(t, lKey, 32)
 
 
 	rb, _ = os.ReadFile(crtF.Name())
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
+	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
-	assert.Equal(t, "test", lCrt.Details.Name)
-	assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String())
-	assert.Len(t, lCrt.Details.Ips, 1)
-	assert.False(t, lCrt.Details.IsCA)
-	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
-	assert.Len(t, lCrt.Details.Subnets, 3)
-	assert.Len(t, lCrt.Details.PublicKey, 32)
-	assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
+	assert.Equal(t, "test", lCrt.Name())
+	assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
+	assert.Len(t, lCrt.Networks(), 1)
+	assert.False(t, lCrt.IsCA())
+	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
+	assert.Len(t, lCrt.UnsafeNetworks(), 3)
+	assert.Len(t, lCrt.PublicKey(), 32)
+	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
 
 
 	sns := []string{}
 	sns := []string{}
-	for _, sn := range lCrt.Details.Subnets {
+	for _, sn := range lCrt.UnsafeNetworks() {
 		sns = append(sns, sn.String())
 		sns = append(sns, sn.String())
 	}
 	}
 	assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
 	assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
 
 
-	issuer, _ := ca.Sha256Sum()
-	assert.Equal(t, issuer, lCrt.Details.Issuer)
+	issuer, _ := ca.Fingerprint()
+	assert.Equal(t, issuer, lCrt.Issuer())
 
 
 	assert.True(t, lCrt.CheckSignature(caPub))
 	assert.True(t, lCrt.CheckSignature(caPub))
 
 
@@ -289,37 +294,39 @@ func Test_signCert(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// read cert file and check pub key matches in-pub
 	// read cert file and check pub key matches in-pub
 	rb, _ = os.ReadFile(crtF.Name())
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
+	lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, lCrt.Details.PublicKey, inPub)
+	assert.Equal(t, lCrt.PublicKey(), inPub)
 
 
 	// test refuse to sign cert with duration beyond root
 	// test refuse to sign cert with duration beyond root
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// create valid cert/key for overwrite tests
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 
 	// test that we won't overwrite existing key file
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -327,14 +334,14 @@ func Test_signCert(t *testing.T) {
 	// create valid cert/key for overwrite tests
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 
 	// test that we won't overwrite existing certificate file
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -361,20 +368,12 @@ func Test_signCert(t *testing.T) {
 	b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
 	b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
 	caKeyF.Write(b)
 	caKeyF.Write(b)
 
 
-	ca = cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(time.Minute * 200),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	b, _ = ca.MarshalToPEM()
+	ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
+	b, _ = ca.MarshalPEM()
 	caCrtF.Write(b)
 	caCrtF.Write(b)
 
 
 	// test with the proper password
 	// test with the proper password
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, testpw))
 	assert.Nil(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -384,7 +383,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 	eb.Reset()
 
 
 	testpw.password = []byte("invalid password")
 	testpw.password = []byte("invalid password")
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, testpw))
 	assert.Error(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -393,7 +392,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, nopw))
 	assert.Error(t, signCert(args, ob, eb, nopw))
 	// normally the user hitting enter on the prompt would add newlines between these
 	// normally the user hitting enter on the prompt would add newlines between these
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@@ -403,7 +402,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, errpw))
 	assert.Error(t, signCert(args, ob, eb, errpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())

+ 26 - 15
cmd/nebula-cert/verify.go

@@ -1,6 +1,7 @@
 package main
 package main
 
 
 import (
 import (
+	"errors"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 
 	rawCACert, err := os.ReadFile(*vf.caPath)
 	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
 	if err != nil {
-		return fmt.Errorf("error while reading ca: %s", err)
+		return fmt.Errorf("error while reading ca: %w", err)
 	}
 	}
 
 
 	caPool := cert.NewCAPool()
 	caPool := cert.NewCAPool()
 	for {
 	for {
-		rawCACert, err = caPool.AddCACertificate(rawCACert)
+		rawCACert, err = caPool.AddCAFromPEM(rawCACert)
 		if err != nil {
 		if err != nil {
-			return fmt.Errorf("error while adding ca cert to pool: %s", err)
+			return fmt.Errorf("error while adding ca cert to pool: %w", err)
 		}
 		}
 
 
 		if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
 		if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 
 	rawCert, err := os.ReadFile(*vf.certPath)
 	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
 	if err != nil {
-		return fmt.Errorf("unable to read crt; %s", err)
+		return fmt.Errorf("unable to read crt: %w", err)
 	}
 	}
-
-	c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
-	if err != nil {
-		return fmt.Errorf("error while parsing crt: %s", err)
-	}
-
-	good, err := c.Verify(time.Now(), caPool)
-	if !good {
-		return err
+	var errs []error
+	for {
+		if len(rawCert) == 0 {
+			break
+		}
+		c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
+		if err != nil {
+			return fmt.Errorf("error while parsing crt: %w", err)
+		}
+		rawCert = extra
+		_, err = caPool.VerifyCertificate(time.Now(), c)
+		if err != nil {
+			switch {
+			case errors.Is(err, cert.ErrCaNotFound):
+				errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
+			default:
+				errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
+			}
+		}
 	}
 	}
 
 
-	return nil
+	return errors.Join(errs...)
 }
 }
 
 
 func verifySummary() string {
 func verifySummary() string {
@@ -80,7 +91,7 @@ func verifySummary() string {
 
 
 func verifyHelp(out io.Writer) {
 func verifyHelp(out io.Writer) {
 	vf := newVerifyFlags()
 	vf := newVerifyFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
 	vf.set.SetOutput(out)
 	vf.set.SetOutput(out)
 	vf.set.PrintDefaults()
 	vf.set.PrintDefaults()
 }
 }

+ 13 - 30
cmd/nebula-cert/verify_test.go

@@ -3,6 +3,7 @@ package main
 import (
 import (
 	"bytes"
 	"bytes"
 	"crypto/rand"
 	"crypto/rand"
+	"errors"
 	"os"
 	"os"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -67,17 +68,8 @@ func Test_verify(t *testing.T) {
 
 
 	// make a ca for later
 	// make a ca for later
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
-	ca := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test-ca",
-			NotBefore: time.Now().Add(time.Hour * -1),
-			NotAfter:  time.Now().Add(time.Hour * 2),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	ca.Sign(cert.Curve_CURVE25519, caPriv)
-	b, _ := ca.MarshalToPEM()
+	ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
+	b, _ := ca.MarshalPEM()
 	caFile.Truncate(0)
 	caFile.Truncate(0)
 	caFile.Seek(0, 0)
 	caFile.Seek(0, 0)
 	caFile.Write(b)
 	caFile.Write(b)
@@ -86,7 +78,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
+	assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 
 	// invalid crt at path
 	// invalid crt at path
 	ob.Reset()
 	ob.Reset()
@@ -102,22 +94,13 @@ func Test_verify(t *testing.T) {
 	assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
 
 
 	// unverifiable cert at path
 	// unverifiable cert at path
-	_, badPriv, _ := ed25519.GenerateKey(rand.Reader)
-	certPub, _ := x25519Keypair()
-	signer, _ := ca.Sha256Sum()
-	crt := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test-cert",
-			NotBefore: time.Now().Add(time.Hour * -1),
-			NotAfter:  time.Now().Add(time.Hour),
-			PublicKey: certPub,
-			IsCA:      false,
-			Issuer:    signer,
-		},
+	crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
+	// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
+	pub := crt.PublicKey()
+	for i, _ := range pub {
+		pub[i] = 0
 	}
 	}
-
-	crt.Sign(cert.Curve_CURVE25519, badPriv)
-	b, _ = crt.MarshalToPEM()
+	b, _ = crt.MarshalPEM()
 	certFile.Truncate(0)
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
 	certFile.Seek(0, 0)
 	certFile.Write(b)
 	certFile.Write(b)
@@ -125,11 +108,11 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "certificate signature did not match")
+	assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
 
 
 	// verified cert at path
 	// verified cert at path
-	crt.Sign(cert.Curve_CURVE25519, caPriv)
-	b, _ = crt.MarshalToPEM()
+	crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
+	b, _ = crt.MarshalPEM()
 	certFile.Truncate(0)
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
 	certFile.Seek(0, 0)
 	certFile.Write(b)
 	certFile.Write(b)

+ 0 - 3
config/config_test.go

@@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) {
 		"new": "hi",
 		"new": "hi",
 	}
 	}
 	assert.Equal(t, expected, c.Settings)
 	assert.Equal(t, expected, c.Settings)
-
-	//TODO: test symlinked file
-	//TODO: test symlinked directory
 }
 }
 
 
 func TestConfig_Get(t *testing.T) {
 func TestConfig_Get(t *testing.T) {

+ 61 - 36
connection_manager.go

@@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 	case deleteTunnel:
 	case deleteTunnel:
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 		}
 
 
 	case closeTunnel:
 	case closeTunnel:
@@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 
 	for _, r := range relayFor {
 	for _, r := range relayFor {
-		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
 
 
 		var index uint32
 		var index uint32
 		var relayFrom netip.Addr
 		var relayFrom netip.Addr
@@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			index = existing.LocalIndex
 			switch r.Type {
 			switch r.Type {
 			case TerminalType:
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = existing.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = existing.PeerAddr
 			case ForwardingType:
 			case ForwardingType:
-				relayFrom = existing.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = existing.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 			default:
 				// should never happen
 				// should never happen
 			}
 			}
@@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			n.relayUsedLock.RUnlock()
 			n.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
 			if err != nil {
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 				continue
 			}
 			}
 			switch r.Type {
 			switch r.Type {
 			case TerminalType:
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = r.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = r.PeerAddr
 			case ForwardingType:
 			case ForwardingType:
-				relayFrom = r.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = r.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 			default:
 				// should never happen
 				// should never happen
 			}
 			}
 		}
 		}
 
 
-		//TODO: IPV6-WORK
-		relayFromB := relayFrom.As4()
-		relayToB := relayTo.As4()
-
 		// Send a CreateRelayRequest to the peer.
 		// Send a CreateRelayRequest to the peer.
 		req := NebulaControl{
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         binary.BigEndian.Uint32(relayFromB[:]),
-			RelayToIp:           binary.BigEndian.Uint32(relayToB[:]),
 		}
 		}
+
+		switch newhostinfo.GetCert().Certificate.Version() {
+		case cert.Version1:
+			if !relayFrom.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				continue
+			}
+
+			if !relayTo.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				continue
+			}
+
+			b := relayFrom.As4()
+			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = relayTo.As4()
+			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		case cert.Version2:
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
+			req.RelayToAddr = netAddrToProtoAddr(relayTo)
+		default:
+			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			continue
+		}
+
 		msg, err := req.Marshal()
 		msg, err := req.Marshal()
 		if err != nil {
 		if err != nil {
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           req.RelayFromIp,
-				"relayTo":             req.RelayToIp,
+				"relayFrom":           req.RelayFromAddr,
+				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
-				"vpnIp":               newhostinfo.vpnIp}).
+				"vpnAddrs":            newhostinfo.vpnAddrs}).
 				Info("send CreateRelayRequest")
 				Info("send CreateRelayRequest")
 		}
 		}
 	}
 	}
@@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		return closeTunnel, hostinfo, nil
 		return closeTunnel, hostinfo, nil
 	}
 	}
 
 
-	primary := n.hostMap.Hosts[hostinfo.vpnIp]
+	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
 		mainHostInfo = false
@@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 	// Let's sort this out.
 
 
-	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
-		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
-		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
-		// The remotes vpn ip is lower than mine. I will not flip.
+	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
+	// vpn addr is static across all tunnels for this host pair so lets
+	// use that to determine if we should consider swapping.
+	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
+		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 		return false
 	}
 	}
 
 
-	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
+	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
+	// settle down.
+	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 }
 
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 	n.hostMap.Lock()
 	n.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnIp] == primary {
+	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
 		n.hostMap.unlockedMakePrimary(current)
 		n.hostMap.unlockedMakePrimary(current)
 	}
 	}
 	n.hostMap.Unlock()
 	n.hostMap.Unlock()
@@ -436,8 +458,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 		return false
 	}
 	}
 
 
-	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
-	if valid {
+	caPool := n.intf.pki.GetCAPool()
+	err := caPool.VerifyCachedCertificate(now, remoteCert)
+	if err == nil {
 		return false
 		return false
 	}
 	}
 
 
@@ -446,9 +469,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 		return false
 	}
 	}
 
 
-	fingerprint, _ := remoteCert.Sha256Sum()
 	hostinfo.logger(n.l).WithError(err).
 	hostinfo.logger(n.l).WithError(err).
-		WithField("fingerprint", fingerprint).
+		WithField("fingerprint", remoteCert.Fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
 
 	return true
 	return true
@@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 }
 
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
+	cs := n.intf.pki.getCertState()
+	curCrt := hostinfo.ConnectionState.myCert
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 		return
 	}
 	}
 
 
-	n.l.WithField("vpnIp", hostinfo.vpnIp).
+	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 		Info("Re-handshaking with remote")
 
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }
 }

+ 162 - 61
connection_manager_test.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"context"
 	"crypto/ed25519"
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/rand"
-	"net"
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -35,20 +34,19 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
@@ -75,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 
 
 	// Add an ip we have established a connection w/ to hostmap
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 		remoteIndexId: 9901,
 	}
 	}
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &cert.NebulaCertificate{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 		H:      &noise.HandshakeState{},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -89,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.Out(hostinfo.localIndexId)
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 
 
@@ -106,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 
 	// Do a final traffic check tick, the host should now be removed
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 }
 
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
@@ -158,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 
 
 	// Add an ip we have established a connection w/ to hostmap
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 		remoteIndexId: 9901,
 	}
 	}
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &cert.NebulaCertificate{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 		H:      &noise.HandshakeState{},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -171,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	// We saw traffic out to vpnIp
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@@ -188,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 
 	// We saw traffic, should no longer be pending deletion
 	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
@@ -197,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
 }
 
 
 // Check if we can disconnect the peer.
 // Check if we can disconnect the peer.
@@ -206,55 +203,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	now := time.Now()
 	now := time.Now()
 	l := test.NewLogger()
 	l := test.NewLogger()
-	ipNet := net.IPNet{
-		IP:   net.IPv4(172, 1, 1, 2),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
+
 	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 
 	// Generate keys for CA and peer's cert.
 	// Generate keys for CA and peer's cert.
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
-	caCert := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: now,
-			NotAfter:  now.Add(1 * time.Hour),
-			IsCA:      true,
-			PublicKey: pubCA,
-		},
+	tbs := &cert.TBSCertificate{
+		Version:   1,
+		Name:      "ca",
+		IsCA:      true,
+		NotBefore: now,
+		NotAfter:  now.Add(1 * time.Hour),
+		PublicKey: pubCA,
 	}
 	}
 
 
-	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
-	ncp := &cert.NebulaCAPool{
-		CAs: cert.NewCAPool().CAs,
-	}
-	ncp.CAs["ca"] = &caCert
+	caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, err)
+	ncp := cert.NewCAPool()
+	assert.NoError(t, ncp.AddCA(caCert))
 
 
 	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
 	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
-	peerCert := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "host",
-			Ips:       []*net.IPNet{&ipNet},
-			Subnets:   []*net.IPNet{},
-			NotBefore: now,
-			NotAfter:  now.Add(60 * time.Second),
-			PublicKey: pubCrt,
-			IsCA:      false,
-			Issuer:    "ca",
-		},
+	tbs = &cert.TBSCertificate{
+		Version:   1,
+		Name:      "host",
+		Networks:  []netip.Prefix{vpncidr},
+		NotBefore: now,
+		NotAfter:  now.Add(60 * time.Second),
+		PublicKey: pubCrt,
 	}
 	}
-	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
+	peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, err)
+
+	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
@@ -280,10 +270,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.connectionManager = nc
 	ifce.connectionManager = nc
 
 
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp: vpnIp,
+		vpnAddrs: []netip.Addr{vpnIp},
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
-			myCert:   &cert.NebulaCertificate{},
-			peerCert: &peerCert,
+			myCert:   &dummyCert{},
+			peerCert: cachedPeerCert,
 			H:        &noise.HandshakeState{},
 			H:        &noise.HandshakeState{},
 		},
 		},
 	}
 	}
@@ -303,3 +293,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
 	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
 	assert.True(t, invalid)
 	assert.True(t, invalid)
 }
 }
+
+type dummyCert struct {
+	version        cert.Version
+	curve          cert.Curve
+	groups         []string
+	isCa           bool
+	issuer         string
+	name           string
+	networks       []netip.Prefix
+	notAfter       time.Time
+	notBefore      time.Time
+	publicKey      []byte
+	signature      []byte
+	unsafeNetworks []netip.Prefix
+}
+
+func (d *dummyCert) Version() cert.Version {
+	return d.version
+}
+
+func (d *dummyCert) Curve() cert.Curve {
+	return d.curve
+}
+
+func (d *dummyCert) Groups() []string {
+	return d.groups
+}
+
+func (d *dummyCert) IsCA() bool {
+	return d.isCa
+}
+
+func (d *dummyCert) Issuer() string {
+	return d.issuer
+}
+
+func (d *dummyCert) Name() string {
+	return d.name
+}
+
+func (d *dummyCert) Networks() []netip.Prefix {
+	return d.networks
+}
+
+func (d *dummyCert) NotAfter() time.Time {
+	return d.notAfter
+}
+
+func (d *dummyCert) NotBefore() time.Time {
+	return d.notBefore
+}
+
+func (d *dummyCert) PublicKey() []byte {
+	return d.publicKey
+}
+
+func (d *dummyCert) Signature() []byte {
+	return d.signature
+}
+
+func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
+	return d.unsafeNetworks
+}
+
+func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
+	return nil
+}
+
+func (d *dummyCert) CheckSignature(key []byte) bool {
+	return true
+}
+
+func (d *dummyCert) Expired(t time.Time) bool {
+	return false
+}
+
+func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
+	return nil
+}
+
+func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
+	return nil
+}
+
+func (d *dummyCert) String() string {
+	return ""
+}
+
+func (d *dummyCert) Marshal() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) MarshalPEM() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Fingerprint() (string, error) {
+	return "", nil
+}
+
+func (d *dummyCert) MarshalJSON() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Copy() cert.Certificate {
+	return d
+}

+ 32 - 23
connection_state.go

@@ -3,6 +3,7 @@ package nebula
 import (
 import (
 	"crypto/rand"
 	"crypto/rand"
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 
 
@@ -18,50 +19,54 @@ type ConnectionState struct {
 	eKey           *NebulaCipherState
 	eKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	H              *noise.HandshakeState
 	H              *noise.HandshakeState
-	myCert         *cert.NebulaCertificate
-	peerCert       *cert.NebulaCertificate
+	myCert         cert.Certificate
+	peerCert       *cert.CachedCertificate
 	initiator      bool
 	initiator      bool
 	messageCounter atomic.Uint64
 	messageCounter atomic.Uint64
 	window         *Bits
 	window         *Bits
 	writeLock      sync.Mutex
 	writeLock      sync.Mutex
 }
 }
 
 
-func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
 	var dhFunc noise.DHFunc
 	var dhFunc noise.DHFunc
-	switch certState.Certificate.Details.Curve {
+	switch crt.Curve() {
 	case cert.Curve_CURVE25519:
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
 	case cert.Curve_P256:
-		dhFunc = noiseutil.DHP256
+		if cs.pkcs11Backed {
+			dhFunc = noiseutil.DHP256PKCS11
+		} else {
+			dhFunc = noiseutil.DHP256
+		}
 	default:
 	default:
-		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
-		return nil
+		return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
 	}
 	}
 
 
-	var cs noise.CipherSuite
-	if cipher == "chachapoly" {
-		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	var ncs noise.CipherSuite
+	if cs.cipher == "chachapoly" {
+		ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	} else {
 	} else {
-		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
+		ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 	}
 
 
-	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
+	static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
 
 
 	b := NewBits(ReplayWindow)
 	b := NewBits(ReplayWindow)
-	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
+	// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
 	b.Update(l, 0)
 	b.Update(l, 0)
 
 
 	hs, err := noise.NewHandshakeState(noise.Config{
 	hs, err := noise.NewHandshakeState(noise.Config{
-		CipherSuite:           cs,
-		Random:                rand.Reader,
-		Pattern:               pattern,
-		Initiator:             initiator,
-		StaticKeypair:         static,
-		PresharedKey:          psk,
-		PresharedKeyPlacement: pskStage,
+		CipherSuite:   ncs,
+		Random:        rand.Reader,
+		Pattern:       pattern,
+		Initiator:     initiator,
+		StaticKeypair: static,
+		//NOTE: These should come from CertState (pki.go) when we finally implement it
+		PresharedKey:          []byte{},
+		PresharedKeyPlacement: 0,
 	})
 	})
 	if err != nil {
 	if err != nil {
-		return nil
+		return nil, fmt.Errorf("NewConnectionState: %s", err)
 	}
 	}
 
 
 	// The queue and ready params prevent a counter race that would happen when
 	// The queue and ready params prevent a counter race that would happen when
@@ -70,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		H:         hs,
 		initiator: initiator,
 		initiator: initiator,
 		window:    b,
 		window:    b,
-		myCert:    certState.Certificate,
+		myCert:    crt,
 	}
 	}
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	ci.messageCounter.Add(2)
 	ci.messageCounter.Add(2)
 
 
-	return ci
+	return ci, nil
 }
 }
 
 
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@@ -85,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"message_counter": cs.messageCounter.Load(),
 		"message_counter": cs.messageCounter.Load(),
 	})
 	})
 }
 }
+
+func (cs *ConnectionState) Curve() cert.Curve {
+	return cs.myCert.Curve()
+}

+ 37 - 36
control.go

@@ -19,9 +19,9 @@ import (
 type controlEach func(h *HostInfo)
 type controlEach func(h *HostInfo)
 
 
 type controlHostLister interface {
 type controlHostLister interface {
-	QueryVpnIp(vpnIp netip.Addr) *HostInfo
+	QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
 	ForEachIndex(each controlEach)
-	ForEachVpnIp(each controlEach)
+	ForEachVpnAddr(each controlEach)
 	GetPreferredRanges() []netip.Prefix
 	GetPreferredRanges() []netip.Prefix
 }
 }
 
 
@@ -37,15 +37,15 @@ type Control struct {
 }
 }
 
 
 type ControlHostInfo struct {
 type ControlHostInfo struct {
-	VpnIp                  netip.Addr              `json:"vpnIp"`
-	LocalIndex             uint32                  `json:"localIndex"`
-	RemoteIndex            uint32                  `json:"remoteIndex"`
-	RemoteAddrs            []netip.AddrPort        `json:"remoteAddrs"`
-	Cert                   *cert.NebulaCertificate `json:"cert"`
-	MessageCounter         uint64                  `json:"messageCounter"`
-	CurrentRemote          netip.AddrPort          `json:"currentRemote"`
-	CurrentRelaysToMe      []netip.Addr            `json:"currentRelaysToMe"`
-	CurrentRelaysThroughMe []netip.Addr            `json:"currentRelaysThroughMe"`
+	VpnAddrs               []netip.Addr     `json:"vpnAddrs"`
+	LocalIndex             uint32           `json:"localIndex"`
+	RemoteIndex            uint32           `json:"remoteIndex"`
+	RemoteAddrs            []netip.AddrPort `json:"remoteAddrs"`
+	Cert                   cert.Certificate `json:"cert"`
+	MessageCounter         uint64           `json:"messageCounter"`
+	CurrentRemote          netip.AddrPort   `json:"currentRemote"`
+	CurrentRelaysToMe      []netip.Addr     `json:"currentRelaysToMe"`
+	CurrentRelaysThroughMe []netip.Addr     `json:"currentRelaysThroughMe"`
 }
 }
 
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -130,15 +130,18 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 }
 }
 
 
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
-func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
-	if c.f.myVpnNet.Addr() == vpnIp {
-		return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
+	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
+	if found {
+		// Only returning the default certificate since its impossible
+		// for any other host but ourselves to have more than 1
+		return c.f.pki.getCertState().GetDefaultCertificate().Copy()
 	}
 	}
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 	if hi == nil {
 		return nil
 		return nil
 	}
 	}
-	return hi.GetCert()
+	return hi.GetCert().Certificate.Copy()
 }
 }
 
 
 // CreateTunnel creates a new tunnel to the given vpn ip.
 // CreateTunnel creates a new tunnel to the given vpn ip.
@@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
 
 
 // PrintTunnel creates a new tunnel to the given vpn ip.
 // PrintTunnel creates a new tunnel to the given vpn ip.
 func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
 func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 	if hi == nil {
 		return nil
 		return nil
 	}
 	}
@@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
 	return hi.CopyCache()
 	return hi.CopyCache()
 }
 }
 
 
-// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
+// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
-func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
+func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	var hl controlHostLister
 	if pending {
 	if pending {
 		hl = c.f.handshakeManager
 		hl = c.f.handshakeManager
@@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 		hl = c.f.hostMap
 		hl = c.f.hostMap
 	}
 	}
 
 
-	h := hl.QueryVpnIp(vpnIp)
+	h := hl.QueryVpnAddr(vpnAddr)
 	if h == nil {
 	if h == nil {
 		return nil
 		return nil
 	}
 	}
@@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
 func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return nil
 		return nil
 	}
 	}
@@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return false
 		return false
 	}
 	}
@@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
 // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
 // the int returned is a count of tunnels closed
 // the int returned is a count of tunnels closed
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
-	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
-	lighthouses := c.f.lightHouse.GetLighthouses()
-
 	shutdown := func(h *HostInfo) {
 	shutdown := func(h *HostInfo) {
-		if excludeLighthouses {
-			if _, ok := lighthouses[h.vpnIp]; ok {
-				return
-			}
+		if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
+			return
 		}
 		}
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.closeTunnel(h)
 		c.f.closeTunnel(h)
 
 
-		c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
+		c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
 			Debug("Sending close tunnel message")
 			Debug("Sending close tunnel message")
 		closed++
 		closed++
 	}
 	}
@@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Relays map
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
 	for _, relayingHost := range c.f.hostMap.Relays {
-		relayingHosts[relayingHost.vpnIp] = relayingHost
+		relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
 	}
 	}
 	c.f.hostMap.Unlock()
 	c.f.hostMap.Unlock()
 
 
@@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Hosts map
 	// Grab the hostMap lock to access the Hosts map
 	c.f.hostMap.Lock()
 	c.f.hostMap.Lock()
 	for _, relayHost := range c.f.hostMap.Indexes {
 	for _, relayHost := range c.f.hostMap.Indexes {
-		if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
+		if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
 			hostInfos = append(hostInfos, relayHost)
 			hostInfos = append(hostInfos, relayHost)
 		}
 		}
 	}
 	}
@@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device {
 }
 }
 
 
 func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
-
 	chi := ControlHostInfo{
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp,
+		VpnAddrs:               make([]netip.Addr, len(h.vpnAddrs)),
 		LocalIndex:             h.localIndexId,
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
@@ -285,12 +282,16 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 		CurrentRemote:          h.remote,
 		CurrentRemote:          h.remote,
 	}
 	}
 
 
+	for i, a := range h.vpnAddrs {
+		chi.VpnAddrs[i] = a
+	}
+
 	if h.ConnectionState != nil {
 	if h.ConnectionState != nil {
 		chi.MessageCounter = h.ConnectionState.messageCounter.Load()
 		chi.MessageCounter = h.ConnectionState.messageCounter.Load()
 	}
 	}
 
 
 	if c := h.GetCert(); c != nil {
 	if c := h.GetCert(); c != nil {
-		chi.Cert = c.Copy()
+		chi.Cert = c.Certificate.Copy()
 	}
 	}
 
 
 	return chi
 	return chi
@@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 	hosts := make([]ControlHostInfo, 0)
 	hosts := make([]ControlHostInfo, 0)
 	pr := hl.GetPreferredRanges()
 	pr := hl.GetPreferredRanges()
-	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+	hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 	})
 	})
 	return hosts
 	return hosts

+ 22 - 36
control_test.go

@@ -5,7 +5,6 @@ import (
 	"net/netip"
 	"net/netip"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
-	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
@@ -14,10 +13,13 @@ import (
 )
 )
 
 
 func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 func TestControl_GetHostInfoByVpnIp(t *testing.T) {
+	//TODO: CERT-V2 with multiple certificate versions we have a problem with this test
+	// Some certs versions have different characteristics and each version implements their own Copy() func
+	// which means this is not a good place to test for exposing memory
 	l := test.NewLogger()
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, netip.Prefix{})
+	hm := newHostMap(l)
 	hm.preferredRanges.Store(&[]netip.Prefix{})
 	hm.preferredRanges.Store(&[]netip.Prefix{})
 
 
 	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
 	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@@ -33,42 +35,27 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 	}
 
 
-	crt := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test",
-			Ips:            []*net.IPNet{&ipNet},
-			Subnets:        []*net.IPNet{},
-			Groups:         []string{"default-group"},
-			NotBefore:      time.Unix(1, 0),
-			NotAfter:       time.Unix(2, 0),
-			PublicKey:      []byte{5, 6, 7, 8},
-			IsCA:           false,
-			Issuer:         "the-issuer",
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-		},
-		Signature: []byte{1, 2, 1, 2, 1, 3},
-	}
-
-	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
-	remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+	remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
 
 
 	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
 	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
 	assert.True(t, ok)
 	assert.True(t, ok)
 
 
+	crt := &dummyCert{}
 	hm.unlockedAddHostInfo(&HostInfo{
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remote:  remote1,
 		remotes: remotes,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
-			peerCert: crt,
+			peerCert: &cert.CachedCertificate{Certificate: crt},
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}, &Interface{})
 	}, &Interface{})
 
 
@@ -83,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		vpnIp:         vpnIp2,
+		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}, &Interface{})
 	}, &Interface{})
 
 
@@ -98,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 		l: logrus.New(),
 	}
 	}
 
 
-	thi := c.GetHostInfoByVpnIp(vpnIp, false)
+	thi := c.GetHostInfoByVpnAddr(vpnIp, false)
 
 
 	expectedInfo := ControlHostInfo{
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  vpnIp,
+		VpnAddrs:               []netip.Addr{vpnIp},
 		LocalIndex:             201,
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteIndex:            200,
 		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
 		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
@@ -113,14 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 	}
 
 
 	// Make sure we don't have any unexpected fields
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	assert.EqualValues(t, &expectedInfo, thi)
 	assert.EqualValues(t, &expectedInfo, thi)
-	//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
-	//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
+	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(vpnIp2, false)
+		thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
 	})
 	})
 }
 }
 
 

+ 43 - 22
control_tester.go

@@ -6,8 +6,6 @@ package nebula
 import (
 import (
 	"net/netip"
 	"net/netip"
 
 
-	"github.com/slackhq/nebula/cert"
-
 	"github.com/google/gopacket"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
@@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
 
 
 	if toAddr.Addr().Is4() {
 	if toAddr.Addr().Is4() {
-		remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
 	} else {
 	} else {
-		remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
 	}
 	}
 }
 }
 
 
@@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
 
 
-	remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
+	remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
 }
 }
 
 
 // GetFromTun will pull a packet off the tun side of nebula
 // GetFromTun will pull a packet off the tun side of nebula
@@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 }
 
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
-	//TODO: IPV6-WORK
-	ip := layers.IPv4{
-		Version:  4,
-		TTL:      64,
-		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().Addr().Unmap().AsSlice(),
-		DstIP:    toIp.Unmap().AsSlice(),
+func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
+	serialize := make([]gopacket.SerializableLayer, 0)
+	var netLayer gopacket.NetworkLayer
+	if toAddr.Is6() {
+		if !fromAddr.Is6() {
+			panic("Cant send ipv6 to ipv4")
+		}
+		ip := &layers.IPv6{
+			Version:    6,
+			NextHeader: layers.IPProtocolUDP,
+			SrcIP:      fromAddr.Unmap().AsSlice(),
+			DstIP:      toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
+	} else {
+		if !fromAddr.Is4() {
+			panic("Cant send ipv4 to ipv6")
+		}
+
+		ip := &layers.IPv4{
+			Version:  4,
+			TTL:      64,
+			Protocol: layers.IPProtocolUDP,
+			SrcIP:    fromAddr.Unmap().AsSlice(),
+			DstIP:    toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
 	}
 	}
 
 
 	udp := layers.UDP{
 	udp := layers.UDP{
 		SrcPort: layers.UDPPort(fromPort),
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 		DstPort: layers.UDPPort(toPort),
 	}
 	}
-	err := udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(netLayer)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 		ComputeChecksums: true,
 		ComputeChecksums: true,
 		FixLengths:       true,
 		FixLengths:       true,
 	}
 	}
-	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+
+	serialize = append(serialize, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, serialize...)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 }
 
 
-func (c *Control) GetVpnIp() netip.Addr {
-	return c.f.myVpnNet.Addr()
+func (c *Control) GetVpnAddrs() []netip.Addr {
+	return c.f.myVpnAddrs
 }
 }
 
 
 func (c *Control) GetUDPAddr() netip.AddrPort {
 func (c *Control) GetUDPAddr() netip.AddrPort {
@@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
 }
 }
 
 
 func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
 func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
+	hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		return false
 		return false
 	}
 	}
@@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 	return c.f.hostMap
 }
 }
 
 
-func (c *Control) GetCert() *cert.NebulaCertificate {
-	return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertState() *CertState {
+	return c.f.pki.getCertState()
 }
 }
 
 
 func (c *Control) ReHandshake(vpnIp netip.Addr) {
 func (c *Control) ReHandshake(vpnIp netip.Addr) {

+ 79 - 39
dns_server.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -21,24 +22,39 @@ var dnsAddr string
 
 
 type dnsRecords struct {
 type dnsRecords struct {
 	sync.RWMutex
 	sync.RWMutex
-	dnsMap  map[string]string
-	hostMap *HostMap
+	l               *logrus.Logger
+	dnsMap4         map[string]netip.Addr
+	dnsMap6         map[string]netip.Addr
+	hostMap         *HostMap
+	myVpnAddrsTable *bart.Table[struct{}]
 }
 }
 
 
-func newDnsRecords(hostMap *HostMap) *dnsRecords {
+func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
 	return &dnsRecords{
 	return &dnsRecords{
-		dnsMap:  make(map[string]string),
-		hostMap: hostMap,
+		l:               l,
+		dnsMap4:         make(map[string]netip.Addr),
+		dnsMap6:         make(map[string]netip.Addr),
+		hostMap:         hostMap,
+		myVpnAddrsTable: cs.myVpnAddrsTable,
 	}
 	}
 }
 }
 
 
-func (d *dnsRecords) Query(data string) string {
+func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
+	data = strings.ToLower(data)
 	d.RLock()
 	d.RLock()
 	defer d.RUnlock()
 	defer d.RUnlock()
-	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
-		return r
+	switch q {
+	case dns.TypeA:
+		if r, ok := d.dnsMap4[data]; ok {
+			return r
+		}
+	case dns.TypeAAAA:
+		if r, ok := d.dnsMap6[data]; ok {
+			return r
+		}
 	}
 	}
-	return ""
+
+	return netip.Addr{}
 }
 }
 
 
 func (d *dnsRecords) QueryCert(data string) string {
 func (d *dnsRecords) QueryCert(data string) string {
@@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 		return ""
 	}
 	}
 
 
-	hostinfo := d.hostMap.QueryVpnIp(ip)
+	hostinfo := d.hostMap.QueryVpnAddr(ip)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		return ""
 		return ""
 	}
 	}
@@ -57,43 +73,69 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 		return ""
 	}
 	}
 
 
-	cert := q.Details
-	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
-	return c
+	b, err := q.Certificate.MarshalJSON()
+	if err != nil {
+		return ""
+	}
+	return string(b)
 }
 }
 
 
-func (d *dnsRecords) Add(host, data string) {
+// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
+func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
+	host = strings.ToLower(host)
 	d.Lock()
 	d.Lock()
 	defer d.Unlock()
 	defer d.Unlock()
-	d.dnsMap[strings.ToLower(host)] = data
+	haveV4 := false
+	haveV6 := false
+	for _, addr := range addresses {
+		if addr.Is4() && !haveV4 {
+			d.dnsMap4[host] = addr
+			haveV4 = true
+		} else if addr.Is6() && !haveV6 {
+			d.dnsMap6[host] = addr
+			haveV6 = true
+		}
+		if haveV4 && haveV6 {
+			break
+		}
+	}
 }
 }
 
 
-func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
+func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
+	a, _, _ := net.SplitHostPort(addr)
+	b, err := netip.ParseAddr(a)
+	if err != nil {
+		return false
+	}
+
+	if b.IsLoopback() {
+		return true
+	}
+
+	_, found := d.myVpnAddrsTable.Lookup(b)
+	return found //if we found it in this table, it's good
+}
+
+func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 	for _, q := range m.Question {
 	for _, q := range m.Question {
 		switch q.Qtype {
 		switch q.Qtype {
-		case dns.TypeA:
-			l.Debugf("Query for A %s", q.Name)
-			ip := dnsR.Query(q.Name)
-			if ip != "" {
-				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
+		case dns.TypeA, dns.TypeAAAA:
+			qType := dns.TypeToString[q.Qtype]
+			d.l.Debugf("Query for %s %s", qType, q.Name)
+			ip := d.Query(q.Qtype, q.Name)
+			if ip.IsValid() {
+				rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
 				if err == nil {
 				if err == nil {
 					m.Answer = append(m.Answer, rr)
 					m.Answer = append(m.Answer, rr)
 				}
 				}
 			}
 			}
 		case dns.TypeTXT:
 		case dns.TypeTXT:
-			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b, err := netip.ParseAddr(a)
-			if err != nil {
+			// We only answer these queries from nebula nodes or localhost
+			if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
 				return
 				return
 			}
 			}
-
-			// We don't answer these queries from non nebula nodes or localhost
-			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
-			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
-				return
-			}
-			l.Debugf("Query for TXT %s", q.Name)
-			ip := dnsR.QueryCert(q.Name)
+			d.l.Debugf("Query for TXT %s", q.Name)
+			ip := d.QueryCert(q.Name)
 			if ip != "" {
 			if ip != "" {
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				if err == nil {
 				if err == nil {
@@ -108,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 	}
 	}
 }
 }
 
 
-func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
+func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.SetReply(r)
 	m.Compress = false
 	m.Compress = false
 
 
 	switch r.Opcode {
 	switch r.Opcode {
 	case dns.OpcodeQuery:
 	case dns.OpcodeQuery:
-		parseQuery(l, m, w)
+		d.parseQuery(m, w)
 	}
 	}
 
 
 	w.WriteMsg(m)
 	w.WriteMsg(m)
 }
 }
 
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
-	dnsR = newDnsRecords(hostMap)
+func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
+	dnsR = newDnsRecords(l, cs, hostMap)
 
 
 	// attach request handler func
 	// attach request handler func
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		handleDnsRequest(l, w, r)
-	})
+	dns.HandleFunc(".", dnsR.handleDnsRequest)
 
 
 	c.RegisterReloadCallback(func(c *config.C) {
 	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)
 		reloadDns(l, c)

+ 20 - 5
dns_server_test.go

@@ -1,23 +1,38 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"net/netip"
 	"testing"
 	"testing"
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestParsequery(t *testing.T) {
 func TestParsequery(t *testing.T) {
-	//TODO: This test is basically pointless
+	l := logrus.New()
 	hostMap := &HostMap{}
 	hostMap := &HostMap{}
-	ds := newDnsRecords(hostMap)
-	ds.Add("test.com.com", "1.2.3.4")
+	ds := newDnsRecords(l, &CertState{}, hostMap)
+	addrs := []netip.Addr{
+		netip.MustParseAddr("1.2.3.4"),
+		netip.MustParseAddr("1.2.3.5"),
+		netip.MustParseAddr("fd01::24"),
+		netip.MustParseAddr("fd01::25"),
+	}
+	ds.Add("test.com.com", addrs)
 
 
-	m := new(dns.Msg)
+	m := &dns.Msg{}
 	m.SetQuestion("test.com.com", dns.TypeA)
 	m.SetQuestion("test.com.com", dns.TypeA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
 
 
-	//parseQuery(m)
+	m = &dns.Msg{}
+	m.SetQuestion("test.com.com", dns.TypeAAAA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
 }
 }
 
 
 func Test_getDnsServerAddr(t *testing.T) {
 func Test_getDnsServerAddr(t *testing.T) {

File diff suppressed because it is too large
+ 389 - 168
e2e/handshakes_test.go


+ 0 - 125
e2e/helpers.go

@@ -1,125 +0,0 @@
-package e2e
-
-import (
-	"crypto/rand"
-	"io"
-	"net"
-	"net/netip"
-	"time"
-
-	"github.com/slackhq/nebula/cert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-)
-
-// NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = make([]*net.IPNet, len(ips))
-		for i, ip := range ips {
-			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
-		}
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = make([]*net.IPNet, len(subnets))
-		for i, ip := range subnets {
-			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
-		}
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(cert.Curve_CURVE25519, priv)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, priv, pem
-}
-
-// NewTestCert will generate a signed certificate with the provided details.
-// Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		panic(err)
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	pub, rawPriv := x25519Keypair()
-	ipb := ip.Addr().AsSlice()
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name: name,
-			Ips:  []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
-			//Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}

+ 75 - 38
e2e/helpers_test.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"io"
 	"net/netip"
 	"net/netip"
 	"os"
 	"os"
+	"strings"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -17,6 +18,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
@@ -26,27 +28,37 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 // newSimpleServer creates a nebula instance with many assumptions
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
+func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 	l := NewTestLogger()
 
 
-	vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
-	if err != nil {
-		panic(err)
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
 	}
 	}
 
 
 	var udpAddr netip.AddrPort
 	var udpAddr netip.AddrPort
-	if vpnIpNet.Addr().Is4() {
-		budpIp := vpnIpNet.Addr().As4()
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
 		budpIp[1] -= 128
 		budpIp[1] -= 128
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
 	} else {
 	} else {
-		budpIp := vpnIpNet.Addr().As16()
-		budpIp[13] -= 128
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
 	}
-	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
 
-	caB, err := caCrt.MarshalToPEM()
+	caB, err := caCrt.MarshalPEM()
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -88,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
 	}
 	}
 
 
 	if overrides != nil {
 	if overrides != nil {
-		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		final := m{}
+		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
 		if err != nil {
 		if err != nil {
 			panic(err)
 			panic(err)
 		}
 		}
-		mc = overrides
+		mc = final
 	}
 	}
 
 
 	cb, err := yaml.Marshal(mc)
 	cb, err := yaml.Marshal(mc)
@@ -109,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
 		panic(err)
 		panic(err)
 	}
 	}
 
 
-	return control, vpnIpNet, udpAddr, c
+	return control, vpnNetworks, udpAddr, c
 }
 }
 
 
 type doneCb func()
 type doneCb func()
@@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 
 
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	// Send a packet from them to me
-	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
+	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 
 	// And once more from me to them
 	// And once more from me to them
-	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
+	controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 }
 
 
-func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
 	// Get both host infos
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
+	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
+	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
 
-	hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
+	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 
 	// Check that both vpn and real addr are correct
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
+	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
+	assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
 
 
 	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
 	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
 	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
 	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
 	// Check that our indexes match
 	// Check that our indexes match
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
-
-	//TODO: Would be nice to assert this memory
-	//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
-	//	hBbyIndex := hmA.Indexes[hBinA.localIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
-	//	assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
-	//
-	//	//TODO: remote indexes are susceptible to collision
-	//	hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
-	//	assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
-	//}
-	//
-	//// Check hostmap indexes too
-	//checkIndexes("hmA", hmA, hBinA)
-	//checkIndexes("hmB", hmB, hAinB)
 }
 }
 
 
 func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	if toIp.Is6() {
+		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
+	} else {
+		assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
+	}
+}
+
+func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
+	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+	assert.NotNil(t, v6, "No ipv6 data found")
+
+	assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
+
+	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	assert.NotNil(t, udp, "No udp data found")
+
+	assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
+	assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
+
+	data := packet.ApplicationLayer()
+	assert.NotNil(t, data)
+	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
+}
+
+func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
 	assert.NotNil(t, v4, "No ipv4 data found")
@@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 }
 
 
+func getAddrs(ns []netip.Prefix) []netip.Addr {
+	var a []netip.Addr
+	for _, n := range ns {
+		a = append(a, n.Addr())
+	}
+	return a
+}
+
 func NewTestLogger() *logrus.Logger {
 func NewTestLogger() *logrus.Logger {
 	l := logrus.New()
 	l := logrus.New()
 
 

+ 5 - 4
e2e/router/hostmap.go

@@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	var lines []string
 	var lines []string
 	var globalLines []*edge
 	var globalLines []*edge
 
 
-	clusterName := strings.Trim(c.GetCert().Details.Name, " ")
-	clusterVpnIp := c.GetCert().Details.Ips[0].IP
+	crt := c.GetCertState().GetDefaultCertificate()
+	clusterName := strings.Trim(crt.Name(), " ")
+	clusterVpnIp := crt.Networks()[0].Addr()
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 
 
 	hm := c.GetHostmap()
 	hm := c.GetHostmap()
@@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	for _, idx := range indexes {
 	for _, idx := range indexes {
 		hi, ok := hm.Indexes[idx]
 		hi, ok := hm.Indexes[idx]
 		if ok {
 		if ok {
-			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
-			remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
+			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
+			remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			_ = hi
 			_ = hi
 		}
 		}

+ 44 - 22
e2e/router/router.go

@@ -10,8 +10,8 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"reflect"
 	"reflect"
+	"regexp"
 	"sort"
 	"sort"
-	"strings"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 			panic("Duplicate listen address: " + addr.String())
 			panic("Duplicate listen address: " + addr.String())
 		}
 		}
 
 
-		r.vpnControls[c.GetVpnIp()] = c
+		for _, vpnAddr := range c.GetVpnAddrs() {
+			r.vpnControls[vpnAddr] = c
+		}
+
 		r.controls[addr] = c
 		r.controls[addr] = c
 	}
 	}
 
 
@@ -213,11 +216,11 @@ func (r *R) renderFlow() {
 			continue
 			continue
 		}
 		}
 		participants[addr] = struct{}{}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr.String(), ":", "-", 1)
+		sanAddr := normalizeName(addr.String())
 		participantsVals = append(participantsVals, sanAddr)
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
-			sanAddr, e.packet.from.GetVpnIp(), sanAddr,
+			sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
 		)
 		)
 	}
 	}
 
 
@@ -250,9 +253,9 @@ func (r *R) renderFlow() {
 
 
 			fmt.Fprintf(f,
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.from.GetUDPAddr().String()),
 				line,
 				line,
-				strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.to.GetUDPAddr().String()),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 			)
 		}
 		}
@@ -267,6 +270,11 @@ func (r *R) renderFlow() {
 	}
 	}
 }
 }
 
 
+func normalizeName(s string) string {
+	rx := regexp.MustCompile("[\\[\\]\\:]")
+	return rx.ReplaceAllLiteralString(s, "_")
+}
+
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
+		return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
 	})
 	})
 
 
 	s := renderHostmaps(c...)
 	s := renderHostmaps(c...)
@@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 		// Nope, lets push the sender along
 		// Nope, lets push the sender along
 		case p := <-udpTx:
 		case p := <-udpTx:
 			r.Lock()
 			r.Lock()
-			c := r.getControl(sender.GetUDPAddr(), p.To, p)
+			a := sender.GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 			if c == nil {
 				r.Unlock()
 				r.Unlock()
-				panic("No control for udp tx")
+				panic("No control for udp tx " + a.String())
 			}
 			}
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			c.InjectUDPPacket(p)
 			c.InjectUDPPacket(p)
@@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 		} else {
 			// we are a udp tx, route and continue
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
 			p := rx.Interface().(*udp.Packet)
-			c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
+			a := cm[x].GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 			if c == nil {
 				r.Unlock()
 				r.Unlock()
-				panic("No control for udp tx")
+				panic(fmt.Sprintf("No control for udp tx %s", p.To))
 			}
 			}
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			c.InjectUDPPacket(p)
 			c.InjectUDPPacket(p)
@@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
 }
 }
 
 
 func (r *R) formatUdpPacket(p *packet) string {
 func (r *R) formatUdpPacket(p *packet) string {
-	packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
-	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
-	if v4 == nil {
-		panic("not an ipv4 packet")
+	var packet gopacket.Packet
+	var srcAddr netip.Addr
+
+	packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
+	if packet.ErrorLayer() == nil {
+		v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
+	} else {
+		packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
+		v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
 	}
 	}
 
 
 	from := "unknown"
 	from := "unknown"
-	srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
 	if c, ok := r.vpnControls[srcAddr]; ok {
 	if c, ok := r.vpnControls[srcAddr]; ok {
 		from = c.GetUDPAddr().String()
 		from = c.GetUDPAddr().String()
 	}
 	}
 
 
-	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
-	if udp == nil {
+	udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	if udpLayer == nil {
 		panic("not a udp packet")
 		panic("not a udp packet")
 	}
 	}
 
 
 	data := packet.ApplicationLayer()
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
-		udp.SrcPort,
-		udp.DstPort,
+		normalizeName(from),
+		normalizeName(p.to.GetUDPAddr().String()),
+		udpLayer.SrcPort,
+		udpLayer.DstPort,
 		string(data.Payload()),
 		string(data.Payload()),
 	)
 	)
 }
 }

+ 12 - 5
examples/config.yml

@@ -13,6 +13,12 @@ pki:
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   #disconnect_invalid: true
   #disconnect_invalid: true
 
 
+  # default_version controls which certificate version is used in handshakes.
+  # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
+  # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
+  # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
+  # default_version: 1
+
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # The syntax is:
 # The syntax is:
@@ -285,7 +291,6 @@ tun:
     # send multiport handshakes.
     # send multiport handshakes.
     #tx_handshake_delay: 2
     #tx_handshake_delay: 2
 
 
-# TODO
 # Configure logging level
 # Configure logging level
 logging:
 logging:
   # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
   # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
@@ -377,10 +382,12 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
-  #      Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
-  #      if `default_local_cidr_any` is false, otherwise its `any`.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
+  #     If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
+  #     Otherwise the default is any vpn network assigned to via the certificate.
+  #     `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
+  #     If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
   #   ca_name: An issuing CA name
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
   #   ca_sha: An issuing CA shasum
 
 

+ 96 - 94
firewall.go

@@ -22,7 +22,7 @@ import (
 )
 )
 
 
 type FirewallInterface interface {
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
 }
 }
 
 
 type conn struct {
 type conn struct {
@@ -51,10 +51,13 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 	DefaultTimeout time.Duration //linux: 600s
 
 
-	// Used to ensure we don't emit local packets for ips we don't own
-	localIps     *bart.Table[struct{}]
-	assignedCIDR netip.Prefix
-	hasSubnets   bool
+	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
+	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
+	routableNetworks *bart.Table[struct{}]
+
+	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
+	assignedNetworks  []netip.Prefix
+	hasUnsafeNetworks bool
 
 
 	rules        string
 	rules        string
 	rulesVersion uint16
 	rulesVersion uint16
@@ -67,9 +70,9 @@ type Firewall struct {
 }
 }
 
 
 type firewallMetrics struct {
 type firewallMetrics struct {
-	droppedLocalIP  metrics.Counter
-	droppedRemoteIP metrics.Counter
-	droppedNoRule   metrics.Counter
+	droppedLocalAddr  metrics.Counter
+	droppedRemoteAddr metrics.Counter
+	droppedNoRule     metrics.Counter
 }
 }
 
 
 type FirewallConntrack struct {
 type FirewallConntrack struct {
@@ -126,88 +129,87 @@ type firewallLocalCIDR struct {
 }
 }
 
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
-func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
+// The certificate provided should be the highest version loaded in memory.
+func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
 	//TODO: error on 0 duration
 	//TODO: error on 0 duration
-	var min, max time.Duration
+	var tmin, tmax time.Duration
 
 
 	if tcpTimeout < UDPTimeout {
 	if tcpTimeout < UDPTimeout {
-		min = tcpTimeout
-		max = UDPTimeout
+		tmin = tcpTimeout
+		tmax = UDPTimeout
 	} else {
 	} else {
-		min = UDPTimeout
-		max = tcpTimeout
+		tmin = UDPTimeout
+		tmax = tcpTimeout
 	}
 	}
 
 
-	if defaultTimeout < min {
-		min = defaultTimeout
-	} else if defaultTimeout > max {
-		max = defaultTimeout
+	if defaultTimeout < tmin {
+		tmin = defaultTimeout
+	} else if defaultTimeout > tmax {
+		tmax = defaultTimeout
 	}
 	}
 
 
-	localIps := new(bart.Table[struct{}])
-	var assignedCIDR netip.Prefix
-	var assignedSet bool
-	for _, ip := range c.Details.Ips {
-		//TODO: IPV6-WORK the unmap is a bit unfortunate
-		nip, _ := netip.AddrFromSlice(ip.IP)
-		nip = nip.Unmap()
-		nprefix := netip.PrefixFrom(nip, nip.BitLen())
-		localIps.Insert(nprefix, struct{}{})
-
-		if !assignedSet {
-			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = nprefix
-			assignedSet = true
-		}
+	routableNetworks := new(bart.Table[struct{}])
+	var assignedNetworks []netip.Prefix
+	for _, network := range c.Networks() {
+		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
+		routableNetworks.Insert(nprefix, struct{}{})
+		assignedNetworks = append(assignedNetworks, network)
 	}
 	}
 
 
-	for _, n := range c.Details.Subnets {
-		nip, _ := netip.AddrFromSlice(n.IP)
-		ones, _ := n.Mask.Size()
-		nip = nip.Unmap()
-		localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
+	hasUnsafeNetworks := false
+	for _, n := range c.UnsafeNetworks() {
+		routableNetworks.Insert(n, struct{}{})
+		hasUnsafeNetworks = true
 	}
 	}
 
 
 	return &Firewall{
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 		Conntrack: &FirewallConntrack{
 			Conns:      make(map[firewall.Packet]*conn),
 			Conns:      make(map[firewall.Packet]*conn),
-			TimerWheel: NewTimerWheel[firewall.Packet](min, max),
+			TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
 		},
 		},
-		InRules:        newFirewallTable(),
-		OutRules:       newFirewallTable(),
-		TCPTimeout:     tcpTimeout,
-		UDPTimeout:     UDPTimeout,
-		DefaultTimeout: defaultTimeout,
-		localIps:       localIps,
-		assignedCIDR:   assignedCIDR,
-		hasSubnets:     len(c.Details.Subnets) > 0,
-		l:              l,
+		InRules:           newFirewallTable(),
+		OutRules:          newFirewallTable(),
+		TCPTimeout:        tcpTimeout,
+		UDPTimeout:        UDPTimeout,
+		DefaultTimeout:    defaultTimeout,
+		routableNetworks:  routableNetworks,
+		assignedNetworks:  assignedNetworks,
+		hasUnsafeNetworks: hasUnsafeNetworks,
+		l:                 l,
 
 
 		incomingMetrics: firewallMetrics{
 		incomingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
 		},
 		},
 		outgoingMetrics: firewallMetrics{
 		outgoingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
 		},
 		},
 	}
 	}
 }
 }
 
 
-func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
+	certificate := cs.getCertificate(cert.Version2)
+	if certificate == nil {
+		certificate = cs.getCertificate(cert.Version1)
+	}
+
+	if certificate == nil {
+		panic("No certificate available to reconfigure the firewall")
+	}
+
 	fw := NewFirewall(
 	fw := NewFirewall(
 		l,
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
-		nc,
+		certificate,
 		//TODO: max_connections
 		//TODO: max_connections
 	)
 	)
 
 
-	//TODO: Flip to false after v1.9 release
-	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
+	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
 
 
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	switch inboundAction {
 	switch inboundAction {
@@ -287,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		fp = ft.TCP
 		fp = ft.TCP
 	case firewall.ProtoUDP:
 	case firewall.ProtoUDP:
 		fp = ft.UDP
 		fp = ft.UDP
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		fp = ft.ICMP
 		fp = ft.ICMP
 	case firewall.ProtoAny:
 	case firewall.ProtoAny:
 		fp = ft.AnyProto
 		fp = ft.AnyProto
@@ -421,33 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
+func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(fp, h, caPool, localCache) {
 	if f.inConns(fp, h, caPool, localCache) {
 		return nil
 		return nil
 	}
 	}
 
 
 	// Make sure remote address matches nebula certificate
 	// Make sure remote address matches nebula certificate
-	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-		_, ok := remoteCidr.Lookup(fp.RemoteIP)
+	if h.networks != nil {
+		_, ok := h.networks.Lookup(fp.RemoteAddr)
 		if !ok {
 		if !ok {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 			return ErrInvalidRemoteIP
 		}
 		}
 	} else {
 	} else {
-		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.vpnIp {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 			return ErrInvalidRemoteIP
 		}
 		}
 	}
 	}
 
 
 	// Make sure we are supposed to be handling this local ip address
 	// Make sure we are supposed to be handling this local ip address
-	//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-	_, ok := f.localIps.Lookup(fp.LocalIP)
+	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
 	if !ok {
 	if !ok {
-		f.metrics(incoming).droppedLocalIP.Inc(1)
+		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 		return ErrInvalidLocalIP
 	}
 	}
 
 
@@ -492,7 +492,7 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 }
 
 
-func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
+func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
 	if localCache != nil {
 	if localCache != nil {
 		if _, ok := localCache[fp]; ok {
 		if _, ok := localCache[fp]; ok {
 			return true
 			return true
@@ -619,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) {
 	delete(conntrack.Conns, p)
 	delete(conntrack.Conns, p)
 }
 }
 
 
-func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 		return true
 		return true
 	}
 	}
@@ -633,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 		if ft.UDP.match(p, incoming, c, caPool) {
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
@@ -663,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
 	return nil
 	return nil
 }
 }
 
 
-func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	// We don't have any allowed ports, bail
 	// We don't have any allowed ports, bail
 	if fp == nil {
 	if fp == nil {
 		return false
 		return false
@@ -726,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
 	return nil
 	return nil
 }
 }
 
 
-func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	if fc == nil {
 	if fc == nil {
 		return false
 		return false
 	}
 	}
@@ -735,18 +735,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 		return true
 		return true
 	}
 	}
 
 
-	if t, ok := fc.CAShas[c.Details.Issuer]; ok {
+	if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok {
 		if t.match(p, c) {
 		if t.match(p, c) {
 			return true
 			return true
 		}
 		}
 	}
 	}
 
 
-	s, err := caPool.GetCAForCert(c)
+	s, err := caPool.GetCAForCert(c.Certificate)
 	if err != nil {
 	if err != nil {
 		return false
 		return false
 	}
 	}
 
 
-	return fc.CANames[s.Details.Name].match(p, c)
+	return fc.CANames[s.Certificate.Name()].match(p, c)
 }
 }
 
 
 func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
 func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
@@ -826,7 +826,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
 	return false
 	return false
 }
 }
 
 
-func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool {
 	if fr == nil {
 	if fr == nil {
 		return false
 		return false
 	}
 	}
@@ -841,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		found := false
 		found := false
 
 
 		for _, g := range sg.Groups {
 		for _, g := range sg.Groups {
-			if _, ok := c.Details.InvertedGroups[g]; !ok {
+			if _, ok := c.InvertedGroups[g]; !ok {
 				found = false
 				found = false
 				break
 				break
 			}
 			}
@@ -855,42 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 	}
 	}
 
 
 	if fr.Hosts != nil {
 	if fr.Hosts != nil {
-		if flc, ok := fr.Hosts[c.Details.Name]; ok {
+		if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
 			if flc.match(p, c) {
 			if flc.match(p, c) {
 				return true
 				return true
 			}
 			}
 		}
 		}
 	}
 	}
 
 
-	matched := false
-	prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
-	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
-		if prefix.Contains(p.RemoteIP) && val.match(p, c) {
-			matched = true
-			return false
+	for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
+		if v.match(p, c) {
+			return true
 		}
 		}
-		return true
-	})
-	return matched
+	}
+
+	return false
 }
 }
 
 
 func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 	if !localIp.IsValid() {
 	if !localIp.IsValid() {
-		if !f.hasSubnets || f.defaultLocalCIDRAny {
+		if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
 			flc.Any = true
 			flc.Any = true
 			return nil
 			return nil
 		}
 		}
 
 
-		localIp = f.assignedCIDR
+		for _, network := range f.assignedNetworks {
+			flc.LocalCIDR.Insert(network, struct{}{})
+		}
+		return nil
+
 	} else if localIp.Bits() == 0 {
 	} else if localIp.Bits() == 0 {
 		flc.Any = true
 		flc.Any = true
+		return nil
 	}
 	}
 
 
 	flc.LocalCIDR.Insert(localIp, struct{}{})
 	flc.LocalCIDR.Insert(localIp, struct{}{})
 	return nil
 	return nil
 }
 }
 
 
-func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
 	if flc == nil {
 	if flc == nil {
 		return false
 		return false
 	}
 	}
@@ -899,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
 		return true
 		return true
 	}
 	}
 
 
-	_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
 	return ok
 	return ok
 }
 }
 
 

+ 11 - 10
firewall/packet.go

@@ -10,18 +10,19 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 const (
 const (
-	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	ProtoTCP  = 6
-	ProtoUDP  = 17
-	ProtoICMP = 1
+	ProtoAny    = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP    = 6
+	ProtoUDP    = 17
+	ProtoICMP   = 1
+	ProtoICMPv6 = 58
 
 
 	PortAny      = 0  // Special value for matching `port: any`
 	PortAny      = 0  // Special value for matching `port: any`
 	PortFragment = -1 // Special value for matching `port: fragment`
 	PortFragment = -1 // Special value for matching `port: fragment`
 )
 )
 
 
 type Packet struct {
 type Packet struct {
-	LocalIP    netip.Addr
-	RemoteIP   netip.Addr
+	LocalAddr  netip.Addr
+	RemoteAddr netip.Addr
 	LocalPort  uint16
 	LocalPort  uint16
 	RemotePort uint16
 	RemotePort uint16
 	Protocol   uint8
 	Protocol   uint8
@@ -30,8 +31,8 @@ type Packet struct {
 
 
 func (fp *Packet) Copy() *Packet {
 func (fp *Packet) Copy() *Packet {
 	return &Packet{
 	return &Packet{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
+		LocalAddr:  fp.LocalAddr,
+		RemoteAddr: fp.RemoteAddr,
 		LocalPort:  fp.LocalPort,
 		LocalPort:  fp.LocalPort,
 		RemotePort: fp.RemotePort,
 		RemotePort: fp.RemotePort,
 		Protocol:   fp.Protocol,
 		Protocol:   fp.Protocol,
@@ -52,8 +53,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 	}
 	}
 	return json.Marshal(m{
 	return json.Marshal(m{
-		"LocalIP":    fp.LocalIP.String(),
-		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalAddr":  fp.LocalAddr.String(),
+		"RemoteAddr": fp.RemoteAddr.String(),
 		"LocalPort":  fp.LocalPort,
 		"LocalPort":  fp.LocalPort,
 		"RemotePort": fp.RemotePort,
 		"RemotePort": fp.RemotePort,
 		"Protocol":   proto,
 		"Protocol":   proto,

+ 136 - 195
firewall_test.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"bytes"
 	"errors"
 	"errors"
 	"math"
 	"math"
-	"net"
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -14,11 +13,12 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 )
 
 
 func TestNewFirewall(t *testing.T) {
 func TestNewFirewall(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	c := &cert.NebulaCertificate{}
+	c := &dummyCert{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	conntrack := fw.Conntrack
 	conntrack := fw.Conntrack
 	assert.NotNil(t, conntrack)
 	assert.NotNil(t, conntrack)
@@ -60,7 +60,7 @@ func TestFirewall_AddRule(t *testing.T) {
 	ob := &bytes.Buffer{}
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
-	c := &cert.NebulaCertificate{}
+	c := &dummyCert{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
 	assert.NotNil(t, fw.OutRules)
@@ -129,35 +129,30 @@ func TestFirewall_Drop(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 		Fragment:   false,
 	}
 	}
 
 
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			Groups:         []string{"default-group"},
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-			Issuer:         "signer-shasum",
-		},
+	c := dummyCert{
+		name:     "host1",
+		networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
 	}
 	}
 	h := HostInfo{
 	h := HostInfo{
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
-			peerCert: &c,
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
 		},
 		},
-		vpnIp: netip.MustParseAddr("1.2.3.4"),
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.networks, c.unsafeNetworks)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -172,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 
 	// test remote mismatch
 	// test remote mismatch
-	oldRemote := p.RemoteIP
-	p.RemoteIP = netip.MustParseAddr("1.2.3.10")
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
-	p.RemoteIP = oldRemote
+	p.RemoteAddr = oldRemote
 
 
 	// ensure signer doesn't get in the way of group checks
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@@ -190,14 +185,14 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 
 
 	// ensure ca name doesn't get in the way of group checks
 	// ensure ca name doesn't get in the way of group checks
-	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caName doesn't drop on match
 	// test caName doesn't drop on match
-	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
@@ -217,7 +212,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 
 
 	b.Run("fail on proto", func(b *testing.B) {
 	b.Run("fail on proto", func(b *testing.B) {
 		// This benchmark is showing us the cost of failing to match the protocol
 		// This benchmark is showing us the cost of failing to match the protocol
-		c := &cert.NebulaCertificate{}
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
 		}
 		}
@@ -225,28 +222,31 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 
 
 	b.Run("pass proto, fail on port", func(b *testing.B) {
 	b.Run("pass proto, fail on port", func(b *testing.B) {
 		// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
 		// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
-		c := &cert.NebulaCertificate{}
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
 	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
 	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
-		c := &cert.NebulaCertificate{}
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
 		ip := netip.MustParsePrefix("9.254.254.254/32")
 		ip := netip.MustParsePrefix("9.254.254.254/32")
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
 	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
-		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "nope",
-				Ips:            []*net.IPNet{ip},
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
 			},
 			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
@@ -254,25 +254,24 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	})
 	})
 
 
 	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
-		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "nope",
-				Ips:            []*net.IPNet{ip},
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
 			},
 			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
 	b.Run("pass on group on any local cidr", func(b *testing.B) {
 	b.Run("pass on group on any local cidr", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"good-group": {}},
-				Name:           "nope",
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
 			},
 			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
 			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
@@ -280,82 +279,28 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	})
 	})
 
 
 	b.Run("pass on group on specific local cidr", func(b *testing.B) {
 	b.Run("pass on group on specific local cidr", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"good-group": {}},
-				Name:           "nope",
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
 			},
 			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
 	b.Run("pass on name", func(b *testing.B) {
 	b.Run("pass on name", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "good-host",
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "good-host",
 			},
 			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
 			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 		}
 	})
 	})
-	//
-	//b.Run("pass on ip", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
-	//	}
-	//})
-	//
-	//b.Run("pass on local ip", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
-	//	}
-	//})
-	//
-	//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
-	//
-	//b.Run("pass on ip with any port", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
-	//	}
-	//})
-	//
-	//b.Run("pass on local ip with any port", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
-	//	}
-	//})
 }
 }
 
 
 func TestFirewall_Drop2(t *testing.T) {
 func TestFirewall_Drop2(t *testing.T) {
@@ -364,49 +309,47 @@ func TestFirewall_Drop2(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 		Fragment:   false,
 	}
 	}
 
 
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
+	network := netip.MustParsePrefix("1.2.3.4/24")
 
 
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
 		},
 		},
+		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
 	}
 	}
 	h := HostInfo{
 	h := HostInfo{
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 
-	c1 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
+	c1 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
 		},
 		},
+		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
 	}
 	}
 	h1 := HostInfo{
 	h1 := HostInfo{
+		vpnAddrs: []netip.Addr{network.Addr()},
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 			peerCert: &c1,
 		},
 		},
 	}
 	}
-	h1.CreateRemoteCIDR(&c1)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 
-	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
@@ -423,72 +366,68 @@ func TestFirewall_Drop3(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  1,
 		LocalPort:  1,
 		RemotePort: 1,
 		RemotePort: 1,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 		Fragment:   false,
 	}
 	}
 
 
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name: "host-owner",
-			Ips:  []*net.IPNet{&ipNet},
+	network := netip.MustParsePrefix("1.2.3.4/24")
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{network},
 		},
 		},
 	}
 	}
 
 
-	c1 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:   "host1",
-			Ips:    []*net.IPNet{&ipNet},
-			Issuer: "signer-sha-bad",
+	c1 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
+			issuer:   "signer-sha-bad",
 		},
 		},
 	}
 	}
 	h1 := HostInfo{
 	h1 := HostInfo{
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 			peerCert: &c1,
 		},
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h1.CreateRemoteCIDR(&c1)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 
-	c2 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:   "host2",
-			Ips:    []*net.IPNet{&ipNet},
-			Issuer: "signer-sha",
+	c2 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host2",
+			networks: []netip.Prefix{network},
+			issuer:   "signer-sha",
 		},
 		},
 	}
 	}
 	h2 := HostInfo{
 	h2 := HostInfo{
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 			peerCert: &c2,
 		},
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h2.CreateRemoteCIDR(&c2)
+	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
 
 
-	c3 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:   "host3",
-			Ips:    []*net.IPNet{&ipNet},
-			Issuer: "signer-sha-bad",
+	c3 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host3",
+			networks: []netip.Prefix{network},
+			issuer:   "signer-sha-bad",
 		},
 		},
 	}
 	}
 	h3 := HostInfo{
 	h3 := HostInfo{
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 			peerCert: &c3,
 		},
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h3.CreateRemoteCIDR(&c3)
+	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
 
 
-	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
@@ -501,6 +440,11 @@ func TestFirewall_Drop3(t *testing.T) {
 	// c3 should fail because no match
 	// c3 should fail because no match
 	resetConntrack(fw)
 	resetConntrack(fw)
 	assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
+
+	// Test a remote address match
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
+	assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 }
 }
 
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
 func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -509,37 +453,33 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 		Fragment:   false,
 	}
 	}
-
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			Groups:         []string{"default-group"},
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-			Issuer:         "signer-shasum",
+	network := netip.MustParsePrefix("1.2.3.4/24")
+
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
+			groups:   []string{"default-group"},
+			issuer:   "signer-shasum",
 		},
 		},
+		InvertedGroups: map[string]struct{}{"default-group": {}},
 	}
 	}
 	h := HostInfo{
 	h := HostInfo{
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 
-	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
@@ -552,7 +492,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 
 	oldFw := fw
 	oldFw := fw
-	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
@@ -561,7 +501,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 
 	oldFw = fw
 	oldFw = fw
-	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
@@ -641,8 +581,6 @@ func BenchmarkLookup(b *testing.B) {
 			ml(m, a)
 			ml(m, a)
 		}
 		}
 	})
 	})
-
-	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
 }
 }
 
 
 func Test_parsePort(t *testing.T) {
 func Test_parsePort(t *testing.T) {
@@ -688,56 +626,59 @@ func Test_parsePort(t *testing.T) {
 func TestNewFirewallFromConfig(t *testing.T) {
 func TestNewFirewallFromConfig(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	// Test a bad rule definition
 	// Test a bad rule definition
-	c := &cert.NebulaCertificate{}
+	c := &dummyCert{}
+	cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
+	require.NoError(t, err)
+
 	conf := config.NewC(l)
 	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
-	_, err := NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 
 	// Test both port and code
 	// Test both port and code
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	// Test missing host, group, cidr, ca_name and ca_sha
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 
 	// Test code/port error
 	// Test code/port error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
 
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 
 	// Test proto error
 	// Test proto error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 
 	// Test cidr parse error
 	// Test cidr parse error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 
 	// Test local_cidr parse error
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 
 	// Test both group and groups
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 }
 
 

+ 21 - 19
go.mod

@@ -1,52 +1,54 @@
 module github.com/slackhq/nebula
 module github.com/slackhq/nebula
 
 
-go 1.22.0
+go 1.23.6
 
 
-toolchain go1.22.2
+toolchain go1.23.7
 
 
 require (
 require (
-	dario.cat/mergo v1.0.0
+	dario.cat/mergo v1.0.1
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
 	github.com/flynn/noise v1.1.0
-	github.com/gaissmai/bart v0.11.1
+	github.com/gaissmai/bart v0.18.1
 	github.com/gogo/protobuf v1.3.2
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.61
+	github.com/miekg/dns v1.1.62
+	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.19.1
+	github.com/prometheus/client_golang v1.20.4
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
-	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
+	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stretchr/testify v1.9.0
 	github.com/stretchr/testify v1.9.0
-	github.com/vishvananda/netlink v1.2.1-beta.2
-	golang.org/x/crypto v0.26.0
+	github.com/vishvananda/netlink v1.3.0
+	golang.org/x/crypto v0.36.0
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-	golang.org/x/net v0.28.0
-	golang.org/x/sync v0.8.0
-	golang.org/x/sys v0.24.0
-	golang.org/x/term v0.23.0
+	golang.org/x/net v0.37.0
+	golang.org/x/sync v0.12.0
+	golang.org/x/sys v0.31.0
+	golang.org/x/term v0.30.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.34.2
+	google.golang.org/protobuf v1.36.5
 	gopkg.in/yaml.v2 v2.4.0
 	gopkg.in/yaml.v2 v2.4.0
 	gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 	gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 )
 )
 
 
 require (
 require (
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
-	github.com/bits-and-blooms/bitset v1.13.0 // indirect
-	github.com/cespare/xxhash/v2 v2.2.0 // indirect
+	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/google/btree v1.1.2 // indirect
+	github.com/klauspost/compress v1.17.9 // indirect
+	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.5.0 // indirect
-	github.com/prometheus/common v0.48.0 // indirect
-	github.com/prometheus/procfs v0.12.0 // indirect
+	github.com/prometheus/client_model v0.6.1 // indirect
+	github.com/prometheus/common v0.55.0 // indirect
+	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
 	golang.org/x/mod v0.18.0 // indirect
 	golang.org/x/mod v0.18.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
 	golang.org/x/time v0.5.0 // indirect

+ 42 - 37
go.sum

@@ -1,6 +1,6 @@
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
-dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
+dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
+dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -14,11 +14,9 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
-github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
-github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
-github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
-github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
+github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
+github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -70,6 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
 github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
+github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -80,13 +80,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
-github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
+github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
+github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg=
 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg=
@@ -100,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
-github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
+github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI=
+github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
-github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
+github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
+github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
-github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
+github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
+github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
-github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
+github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
+github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -129,8 +135,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
@@ -139,9 +145,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
-github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
-github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
+github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
+github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -151,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
-golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -171,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
-golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
+golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
+golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -180,30 +185,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
-golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
-golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
-golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
+golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -234,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
-google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
+google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
+google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

+ 223 - 104
handshake_ix.go

@@ -2,10 +2,12 @@ package nebula
 
 
 import (
 import (
 	"net/netip"
 	"net/netip"
+	"slices"
 	"time"
 	"time"
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
@@ -17,23 +19,59 @@ import (
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	err := f.handshakeManager.allocateIndex(hh)
 	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 		return false
 	}
 	}
 
 
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
+	// If we're connecting to a v6 address we must use a v2 cert
+	cs := f.pki.getCertState()
+	v := cs.defaultVersion
+	for _, a := range hh.hostinfo.vpnAddrs {
+		if a.Is6() {
+			v = cert.Version2
+			break
+		}
+	}
+
+	crt := cs.getCertificate(v)
+	if crt == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate is available")
+		return false
+	}
+
+	crtHs := cs.getHandshakeBytes(v)
+	if crtHs == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Failed to create connection state")
+		return false
+	}
 	hh.hostinfo.ConnectionState = ci
 	hh.hostinfo.ConnectionState = ci
 
 
-	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hh.hostinfo.localIndexId,
-		Time:           uint64(time.Now().UnixNano()),
-		Cert:           certState.RawCertificateNoKey,
+	hs := &NebulaHandshake{
+		Details: &NebulaHandshakeDetails{
+			InitiatorIndex: hh.hostinfo.localIndexId,
+			Time:           uint64(time.Now().UnixNano()),
+			Cert:           crtHs,
+			CertVersion:    uint32(v),
+		},
 	}
 	}
 
 
 	if f.multiPort.Tx || f.multiPort.Rx {
 	if f.multiPort.Tx || f.multiPort.Rx {
-		hsProto.InitiatorMultiPort = &MultiPortDetails{
+		hs.Details.InitiatorMultiPort = &MultiPortDetails{
 			RxSupported: f.multiPort.Rx,
 			RxSupported: f.multiPort.Rx,
 			TxSupported: f.multiPort.Tx,
 			TxSupported: f.multiPort.Tx,
 			BasePort:    uint32(f.multiPort.TxBasePort),
 			BasePort:    uint32(f.multiPort.TxBasePort),
@@ -41,15 +79,9 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 		}
 		}
 	}
 	}
 
 
-	hsBytes := []byte{}
-
-	hs := &NebulaHandshake{
-		Details: hsProto,
-	}
-	hsBytes, err = hs.Marshal()
-
+	hsBytes, err := hs.Marshal()
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 		return false
 	}
 	}
@@ -58,7 +90,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 		return false
 	}
 	}
@@ -73,30 +105,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 }
 }
 
 
 func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
+	cs := f.pki.getCertState()
+	crt := cs.GetDefaultCertificate()
+	if crt == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", cs.defaultVersion).
+			Error("Unable to handshake with host because no certificate is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to create connection state")
+		return
+	}
+
 	// Mark packet 1 as seen so it doesn't show up as missed
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 	ci.window.Update(f.l, 1)
 
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to call noise.ReadMessage")
 		return
 		return
 	}
 	}
 
 
 	hs := &NebulaHandshake{}
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	err = hs.Unmarshal(msg)
-	/*
-		l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
-	*/
 	if err != nil || hs.Details == nil {
 	if err != nil || hs.Details == nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed unmarshal handshake message")
 		return
 		return
 	}
 	}
 
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
 	if err != nil {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@@ -109,8 +155,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 		return
 	}
 	}
 
 
-	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
-	if !ok {
+	if remoteCert.Certificate.Version() != ci.myCert.Version() {
+		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
+		rc := cs.getCertificate(remoteCert.Certificate.Version())
+		if rc == nil {
+			f.l.WithError(err).WithField("udpAddr", addr).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
+				Info("Unable to handshake with host due to missing certificate version")
+			return
+		}
+
+		// Record the certificate we are actually using
+		ci.myCert = rc
+	}
+
+	if len(remoteCert.Certificate.Networks()) == 0 {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
 
 
@@ -122,30 +181,54 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 		return
 	}
 	}
 
 
-	vpnIp = vpnIp.Unmap()
-	certName := remoteCert.Details.Name
-	fingerprint, _ := remoteCert.Sha256Sum()
-	issuer := remoteCert.Details.Issuer
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	certName := remoteCert.Certificate.Name()
+	fingerprint := remoteCert.Fingerprint
+	issuer := remoteCert.Certificate.Issuer()
+
+	for _, network := range remoteCert.Certificate.Networks() {
+		vpnAddr := network.Addr()
+		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
+		if found {
+			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("fingerprint", fingerprint).
+				WithField("issuer", issuer).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			return
+		}
+
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
 
 
-	if vpnIp == f.myVpnNet.Addr() {
-		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
 		return
 		return
 	}
 	}
 
 
 	if addr.IsValid() {
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		// addr can be invalid when the tunnel is being relayed.
+		// We only want to apply the remote allow list for direct tunnels here
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 			return
 		}
 		}
 	}
 	}
 
 
 	myIndex, err := generateIndex(f.l)
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -177,19 +260,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		ConnectionState:   ci,
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		vpnIp:             vpnIp,
+		vpnAddrs:          vpnAddrs,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		lastHandshakeTime: hs.Details.Time,
 		multiportTx:       multiportTx,
 		multiportTx:       multiportTx,
 		multiportRx:       multiportRx,
 		multiportRx:       multiportRx,
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}
 	}
 
 
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("issuer", issuer).
@@ -199,13 +282,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		Info("Handshake message received")
 		Info("Handshake message received")
 
 
 	hs.Details.ResponderIndex = myIndex
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = certState.RawCertificateNoKey
+	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
+	if hs.Details.Cert == nil {
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("certVersion", ci.myCert.Version()).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return
+	}
+
+	hs.Details.CertVersion = uint32(ci.myCert.Version())
 	// Update the time in case their clock is way off from ours
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
 
 	hsBytes, err := hs.Marshal()
 	hsBytes, err := hs.Marshal()
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -216,14 +312,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 		return
 	} else if dKey == nil || eKey == nil {
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -247,9 +343,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
 	hostinfo.SetRemote(addr)
-	hostinfo.CreateRemoteCIDR(remoteCert)
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
 	if err != nil {
@@ -263,7 +359,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				// the preferred remote
-				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 			}
 
 
 			msg = existing.HandshakePacket[2]
 			msg = existing.HandshakePacket[2]
@@ -278,11 +374,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 					err = f.outside.WriteTo(msg, addr)
 					err = f.outside.WriteTo(msg, addr)
 				}
 				}
 				if err != nil {
 				if err != nil {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 						WithError(err).Error("Failed to send handshake message")
 				} else {
 				} else {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 						Info("Handshake message sent")
 				}
 				}
@@ -292,16 +388,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 					return
 				}
 				}
-				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
+				f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 					Info("Handshake message sent")
 				return
 				return
 			}
 			}
 		case ErrExistingHostInfo:
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@@ -312,23 +408,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				Info("Handshake too old")
 				Info("Handshake too old")
 
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 			return
 		case ErrLocalIndexCollision:
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
 				Error("Failed to add HostInfo due to localIndex collision")
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 			return
 		default:
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -351,7 +447,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			err = f.outside.WriteTo(msg, addr)
 			err = f.outside.WriteTo(msg, addr)
 		}
 		}
 		if err != nil {
 		if err != nil {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -359,7 +455,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake")
 				WithError(err).Error("Failed to send handshake")
 		} else {
 		} else {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -372,9 +468,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 			return
 		}
 		}
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+		// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
+		// it's correctly marked as working.
+		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -401,8 +500,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 
 	hostinfo := hh.hostinfo
 	hostinfo := hh.hostinfo
 	if addr.IsValid() {
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 			return false
 		}
 		}
 	}
 	}
@@ -410,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci := hostinfo.ConnectionState
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 			Error("Failed to call noise.ReadMessage")
 
 
@@ -419,7 +519,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		// near future
 		// near future
 		return false
 		return false
 	} else if dKey == nil || eKey == nil {
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 			Error("Noise did not arrive at a key")
 
 
@@ -431,7 +531,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hs := &NebulaHandshake{}
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -452,9 +552,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		)
 		)
 	}
 	}
 
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
 	if err != nil {
-		e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 
 
 		if f.l.Level > logrus.DebugLevel {
 		if f.l.Level > logrus.DebugLevel {
@@ -467,8 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		return true
 		return true
 	}
 	}
 
 
-	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
-	if !ok {
+	if len(remoteCert.Certificate.Networks()) == 0 {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 
 
@@ -476,18 +575,55 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			e = e.WithField("cert", remoteCert)
 			e = e.WithField("cert", remoteCert)
 		}
 		}
 
 
-		e.Info("Invalid vpn ip from host")
+		e.Info("Empty networks from host")
 		return true
 		return true
 	}
 	}
 
 
-	vpnIp = vpnIp.Unmap()
-	certName := remoteCert.Details.Name
-	fingerprint, _ := remoteCert.Sha256Sum()
-	issuer := remoteCert.Details.Issuer
+	vpnNetworks := remoteCert.Certificate.Networks()
+	certName := remoteCert.Certificate.Name()
+	fingerprint := remoteCert.Fingerprint
+	issuer := remoteCert.Certificate.Issuer()
+
+	hostinfo.remoteIndexId = hs.Details.ResponderIndex
+	hostinfo.lastHandshakeTime = hs.Details.Time
+
+	// Store their cert and our symmetric keys
+	ci.peerCert = remoteCert
+	ci.dKey = NewNebulaCipherState(dKey)
+	ci.eKey = NewNebulaCipherState(eKey)
+
+	// Make sure the current udpAddr being used is set for responding
+	if addr.IsValid() {
+		hostinfo.SetRemote(addr)
+	} else {
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+	}
+
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	for _, network := range vpnNetworks {
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		vpnAddr := network.Addr()
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+		return true
+	}
 
 
 	// Ensure the right host responded
 	// Ensure the right host responded
-	if vpnIp != hostinfo.vpnIp {
-		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
+	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 			Info("Incorrect host responded to handshake")
@@ -496,16 +632,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 
 		// Create a new hostinfo/handshake for the intended vpn ip
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
-			//TODO: this doesnt know if its being added or is being used for caching a packet
+		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes.BlockRemote(addr)
 			newHH.hostinfo.remotes.BlockRemote(addr)
 
 
-			// Get the correct remote list for the host we did handshake with
-			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
-
-			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
+				WithField("vpnNetworks", vpnNetworks).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				Info("Blocked addresses for handshakes")
 				Info("Blocked addresses for handshakes")
 
 
@@ -513,8 +646,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			newHH.packetStore = hh.packetStore
 			newHH.packetStore = hh.packetStore
 			hh.packetStore = []*cachedPacket{}
 			hh.packetStore = []*cachedPacket{}
 
 
-			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-			hostinfo.vpnIp = vpnIp
+			// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnAddrs = vpnAddrs
 			f.sendCloseTunnel(hostinfo)
 			f.sendCloseTunnel(hostinfo)
 		})
 		})
 
 
@@ -525,7 +658,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 	ci.window.Update(f.l, 2)
 
 
 	duration := time.Since(hh.startTime).Nanoseconds()
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("issuer", issuer).
@@ -536,25 +669,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx).
 		WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx).
 		Info("Handshake message received")
 		Info("Handshake message received")
 
 
-	hostinfo.remoteIndexId = hs.Details.ResponderIndex
-	hostinfo.lastHandshakeTime = hs.Details.Time
-
-	// Store their cert and our symmetric keys
-	ci.peerCert = remoteCert
-	ci.dKey = NewNebulaCipherState(dKey)
-	ci.eKey = NewNebulaCipherState(eKey)
-
-	// Make sure the current udpAddr being used is set for responding
-	if addr.IsValid() {
-		hostinfo.SetRemote(addr)
-	} else {
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
-	}
-
 	// Build up the radix for the firewall if we have subnets in the cert
 	// Build up the radix for the firewall if we have subnets in the cert
-	hostinfo.CreateRemoteCIDR(remoteCert)
+	hostinfo.vpnAddrs = vpnAddrs
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
 
-	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
+	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 
 

+ 181 - 124
handshake_manager.go

@@ -7,14 +7,15 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
 	"net/netip"
 	"net/netip"
+	"slices"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
-	"golang.org/x/exp/slices"
 )
 )
 
 
 const (
 const (
@@ -121,18 +122,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) Run(ctx context.Context) {
-	clockSource := time.NewTicker(c.config.tryInterval)
+func (hm *HandshakeManager) Run(ctx context.Context) {
+	clockSource := time.NewTicker(hm.config.tryInterval)
 	defer clockSource.Stop()
 	defer clockSource.Stop()
 
 
 	for {
 	for {
 		select {
 		select {
 		case <-ctx.Done():
 		case <-ctx.Done():
 			return
 			return
-		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, true)
+		case vpnIP := <-hm.trigger:
+			hm.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now)
+			hm.NextOutboundHandshakeTimerTick(now)
 		}
 		}
 	}
 	}
 }
 }
@@ -140,7 +141,7 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
 	// First remote allow list check before we know the vpnIp
 	if addr.IsValid() {
 	if addr.IsValid() {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 			return
 		}
 		}
@@ -162,14 +163,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
-	c.OutboundHandshakeTimer.Advance(now)
+func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
+	hm.OutboundHandshakeTimer.Advance(now)
 	for {
 	for {
-		vpnIp, has := c.OutboundHandshakeTimer.Purge()
+		vpnIp, has := hm.OutboundHandshakeTimer.Purge()
 		if !has {
 		if !has {
 			break
 			break
 		}
 		}
-		c.handleOutbound(vpnIp, false)
+		hm.handleOutbound(vpnIp, false)
 	}
 	}
 }
 }
 
 
@@ -211,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
 	}
 	}
 
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@@ -226,7 +227,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 
 
 	hh.lastRemotes = remotes
 	hh.lastRemotes = remotes
 
 
-	// TODO: this will generate a load of queries for hosts with only 1 ip
+	// This will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
 	// So we only do it one time after attempting 5 handshakes already.
 	if len(remotes) <= 1 && hh.counter == 5 {
 	if len(remotes) <= 1 && hh.counter == 5 {
@@ -293,59 +294,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
+			// Don't relay to myself
+			if relay == vpnIp {
 				continue
 				continue
 			}
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+
+			// Don't relay through the host I'm trying to connect to
+			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
+			if found {
+				continue
+			}
+
+			relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
 			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hm.f.Handshake(relay)
 				hm.f.Handshake(relay)
 				continue
 				continue
 			}
 			}
-			// Check the relay HostInfo to see if we already established a relay through it
-			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
-				switch existingRelay.State {
-				case Established:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
-				case Requested:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
-
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
-					// Re-send the CreateRelay request, in case the previous one was lost.
-					m := NebulaControl{
-						Type:                NebulaControl_CreateRelayRequest,
-						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
-					}
-					msg, err := m.Marshal()
-					if err != nil {
-						hostinfo.logger(hm.l).
-							WithError(err).
-							Error("Failed to marshal Control message to create relay")
-					} else {
-						// This must send over the hostinfo, not over hm.Hosts[ip]
-						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
-							"relayTo":             vpnIp,
-							"initiatorRelayIndex": existingRelay.LocalIndex,
-							"relay":               relay}).
-							Info("send CreateRelayRequest")
-					}
-				default:
-					hostinfo.logger(hm.l).
-						WithField("vpnIp", vpnIp).
-						WithField("state", existingRelay.State).
-						WithField("relay", relayHostInfo.vpnIp).
-						Errorf("Relay unexpected state")
-				}
-			} else {
+			// Check the relay HostInfo to see if we already established a relay through
+			existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
+			if !ok {
 				// No relays exist or requested yet.
 				// No relays exist or requested yet.
 				if relayHostInfo.remote.IsValid() {
 				if relayHostInfo.remote.IsValid() {
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
@@ -353,16 +321,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 					}
 
 
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
 					m := NebulaControl{
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					}
+
+					switch relayHostInfo.GetCert().Certificate.Version() {
+					case cert.Version1:
+						if !hm.f.myVpnAddrs[0].Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+							continue
+						}
+
+						if !vpnIp.Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+							continue
+						}
+
+						b := hm.f.myVpnAddrs[0].As4()
+						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+						b = vpnIp.As4()
+						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					case cert.Version2:
+						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					default:
+						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						continue
+					}
+
 					msg, err := m.Marshal()
 					msg, err := m.Marshal()
 					if err != nil {
 					if err != nil {
 						hostinfo.logger(hm.l).
 						hostinfo.logger(hm.l).
@@ -371,13 +358,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					} else {
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
+							"relayFrom":           hm.f.myVpnAddrs[0],
 							"relayTo":             vpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
 							"relay":               relay}).
 							Info("send CreateRelayRequest")
 							Info("send CreateRelayRequest")
 					}
 					}
 				}
 				}
+				continue
+			}
+
+			switch existingRelay.State {
+			case Established:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+				hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+			case Disestablished:
+				// Mark this relay as 'requested'
+				relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
+				fallthrough
+			case Requested:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+				// Re-send the CreateRelay request, in case the previous one was lost.
+				m := NebulaControl{
+					Type:                NebulaControl_CreateRelayRequest,
+					InitiatorRelayIndex: existingRelay.LocalIndex,
+				}
+
+				switch relayHostInfo.GetCert().Certificate.Version() {
+				case cert.Version1:
+					if !hm.f.myVpnAddrs[0].Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+						continue
+					}
+
+					if !vpnIp.Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+						continue
+					}
+
+					b := hm.f.myVpnAddrs[0].As4()
+					m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+					b = vpnIp.As4()
+					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+				case cert.Version2:
+					m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+				default:
+					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+					continue
+				}
+				msg, err := m.Marshal()
+				if err != nil {
+					hostinfo.logger(hm.l).
+						WithError(err).
+						Error("Failed to marshal Control message to create relay")
+				} else {
+					// This must send over the hostinfo, not over hm.Hosts[ip]
+					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+					hm.l.WithFields(logrus.Fields{
+						"relayFrom":           hm.f.myVpnAddrs[0],
+						"relayTo":             vpnIp,
+						"initiatorRelayIndex": existingRelay.LocalIndex,
+						"relay":               relay}).
+						Info("send CreateRelayRequest")
+				}
+			case PeerRequested:
+				// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
+				fallthrough
+			default:
+				hostinfo.logger(hm.l).
+					WithField("vpnIp", vpnIp).
+					WithField("state", existingRelay.State).
+					WithField("relay", relay).
+					Errorf("Relay unexpected state")
+
 			}
 			}
 		}
 		}
 	}
 	}
@@ -407,10 +461,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 }
 }
 
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 	hm.Lock()
 
 
-	if hh, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnAddr]; ok {
 		// We are already trying to handshake with this vpn ip
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
 		if cacheCb != nil {
 			cacheCb(hh)
 			cacheCb(hh)
@@ -420,12 +474,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 	}
 	}
 
 
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp:           vpnIp,
+		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}
 	}
 
 
@@ -433,9 +487,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 		hostinfo:  hostinfo,
 		hostinfo:  hostinfo,
 		startTime: time.Now(),
 		startTime: time.Now(),
 	}
 	}
-	hm.vpnIps[vpnIp] = hh
+	hm.vpnIps[vpnAddr] = hh
 	hm.metricInitiated.Inc(1)
 	hm.metricInitiated.Inc(1)
-	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
+	hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
 
 
 	if cacheCb != nil {
 	if cacheCb != nil {
 		cacheCb(hh)
 		cacheCb(hh)
@@ -443,21 +497,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// We can trigger the handshake right now
 	// We can trigger the handshake right now
-	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
 	if !doTrigger {
 	if !doTrigger {
 		// Add any calculated remotes, and trigger early handshake if one found
 		// Add any calculated remotes, and trigger early handshake if one found
-		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
 	}
 	}
 
 
 	if doTrigger {
 	if doTrigger {
 		select {
 		select {
-		case hm.trigger <- vpnIp:
+		case hm.trigger <- vpnAddr:
 		default:
 		default:
 		}
 		}
 	}
 	}
 
 
 	hm.Unlock()
 	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp)
+	hm.lightHouse.QueryServer(vpnAddr)
 	return hostinfo
 	return hostinfo
 }
 }
 
 
@@ -478,14 +532,14 @@ var (
 //
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
-	c.Lock()
-	defer c.Unlock()
+func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 
 	// Check if we already have a tunnel with this vpn ip
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
 	if found && existingHostInfo != nil {
 	if found && existingHostInfo != nil {
 		testHostInfo := existingHostInfo
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
 		for testHostInfo != nil {
@@ -502,31 +556,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			return existingHostInfo, ErrExistingHostInfo
 			return existingHostInfo, ErrExistingHostInfo
 		}
 		}
 
 
-		existingHostInfo.logger(c.l).Info("Taking new handshake")
+		existingHostInfo.logger(hm.l).Info("Taking new handshake")
 	}
 	}
 
 
-	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
 	if found {
 	if found {
 		// We have a collision, but for a different hostinfo
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 		return existingIndex, ErrLocalIndexCollision
 	}
 	}
 
 
-	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
 	if found && existingPendingIndex.hostinfo != hostinfo {
 	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
 		// We have a collision, but for a different hostinfo
 		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 	}
 	}
 
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+		hostinfo.logger(hm.l).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 	return existingHostInfo, nil
 }
 }
 
 
@@ -544,7 +598,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(hm.l).
 		hostinfo.logger(hm.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
@@ -581,31 +635,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 	return errors.New("failed to generate unique localIndexId")
 }
 }
 
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	c.Lock()
-	defer c.Unlock()
-	c.unlockedDeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
 }
 }
 
 
-func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	delete(c.vpnIps, hostinfo.vpnIp)
-	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
+func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	for _, addr := range hostinfo.vpnAddrs {
+		delete(hm.vpnIps, addr)
 	}
 	}
 
 
-	delete(c.indexes, hostinfo.localIndexId)
-	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HandshakeHostInfo{}
+	if len(hm.vpnIps) == 0 {
+		hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 	}
 
 
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+	delete(hm.indexes, hostinfo.localIndexId)
+	if len(hm.indexes) == 0 {
+		hm.indexes = map[uint32]*HandshakeHostInfo{}
+	}
+
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Pending hostmap hostInfo deleted")
 			Debug("Pending hostmap hostInfo deleted")
 	}
 	}
 }
 }
 
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
+func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 	if hh != nil {
 		return hh.hostinfo
 		return hh.hostinfo
@@ -634,37 +691,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 	return hm.indexes[index]
 }
 }
 
 
-func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
-	return c.mainHostMap.GetPreferredRanges()
+func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
+	return hm.mainHostMap.GetPreferredRanges()
 }
 }
 
 
-func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
 
-	for _, v := range c.vpnIps {
+	for _, v := range hm.vpnIps {
 		f(v.hostinfo)
 		f(v.hostinfo)
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) ForEachIndex(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
 
-	for _, v := range c.indexes {
+	for _, v := range hm.indexes {
 		f(v.hostinfo)
 		f(v.hostinfo)
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) EmitStats() {
-	c.RLock()
-	hostLen := len(c.vpnIps)
-	indexLen := len(c.indexes)
-	c.RUnlock()
+func (hm *HandshakeManager) EmitStats() {
+	hm.RLock()
+	hostLen := len(hm.vpnIps)
+	indexLen := len(hm.indexes)
+	hm.RUnlock()
 
 
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
-	c.mainHostMap.EmitStats()
+	hm.mainHostMap.EmitStats()
 }
 }
 
 
 // Utility functions below
 // Utility functions below

+ 18 - 11
handshake_manager_test.go

@@ -14,21 +14,20 @@ import (
 
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	ip := netip.MustParseAddr("172.1.1.2")
 	ip := netip.MustParseAddr("172.1.1.2")
 
 
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
-	mainHM := newHostMap(l, vpncidr)
+	mainHM := newHostMap(l)
 	mainHM.preferredRanges.Store(&preferredRanges)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -42,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	i2 := blah.StartHandshake(ip, nil)
 	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 	assert.Same(t, i, i2)
 
 
-	i.remotes = NewRemoteList(nil)
+	i.remotes = NewRemoteList([]netip.Addr{}, nil)
 
 
 	// Adding something to pending should not affect the main hostmap
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
 	assert.Len(t, mainHM.Hosts, 0)
@@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 type mockEncWriter struct {
 type mockEncWriter struct {
 }
 }
 
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
+func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
+
+func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
+	return nil
+}
+
+func (mw *mockEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: cert.Version2}
+}

+ 171 - 104
hostmap.go

@@ -35,6 +35,7 @@ const (
 	Requested = iota
 	Requested = iota
 	PeerRequested
 	PeerRequested
 	Established
 	Established
+	Disestablished
 )
 )
 
 
 const (
 const (
@@ -48,7 +49,7 @@ type Relay struct {
 	State       int
 	State       int
 	LocalIndex  uint32
 	LocalIndex  uint32
 	RemoteIndex uint32
 	RemoteIndex uint32
-	PeerIp      netip.Addr
+	PeerAddr    netip.Addr
 }
 }
 
 
 type HostMap struct {
 type HostMap struct {
@@ -58,7 +59,6 @@ type HostMap struct {
 	RemoteIndexes   map[uint32]*HostInfo
 	RemoteIndexes   map[uint32]*HostInfo
 	Hosts           map[netip.Addr]*HostInfo
 	Hosts           map[netip.Addr]*HostInfo
 	preferredRanges atomic.Pointer[[]netip.Prefix]
 	preferredRanges atomic.Pointer[[]netip.Prefix]
-	vpnCIDR         netip.Prefix
 	l               *logrus.Logger
 	l               *logrus.Logger
 }
 }
 
 
@@ -68,9 +68,12 @@ type HostMap struct {
 type RelayState struct {
 type RelayState struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
-	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
+	relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
+	// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
+	// the RelayState Lock held)
+	relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx  map[uint32]*Relay     // Maps a local index to some Relay info
 }
 }
 
 
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
@@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	delete(rs.relays, ip)
 	delete(rs.relays, ip)
 }
 }
 
 
+func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByAddr[vpnIp]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
+func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByIdx[idx]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
 func (rs *RelayState) CopyAllRelayFor() []*Relay {
 func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
@@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 	return ret
 }
 }
 
 
-func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[ip]
+	r, ok := rs.relayForByAddr[addr]
 	return r, ok
 	return r, ok
 }
 }
 
 
@@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
 func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
-	for relayIp := range rs.relayForByIp {
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
+	for relayIp := range rs.relayForByAddr {
 		currentRelays = append(currentRelays, relayIp)
 		currentRelays = append(currentRelays, relayIp)
 	}
 	}
 	return currentRelays
 	return currentRelays
@@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	if !ok {
 	if !ok {
 		return false
 		return false
 	}
 	}
@@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
 	newRelay.State = Established
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return true
 	return true
 }
 }
 
 
@@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	newRelay.State = Established
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return &newRelay, true
 	return &newRelay, true
 }
 }
 
 
 func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	return r, ok
 	return r, ok
 }
 }
 
 
@@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
-	rs.relayForByIp[ip] = r
+	rs.relayForByAddr[ip] = r
 	rs.relayForByIdx[idx] = r
 	rs.relayForByIdx[idx] = r
 }
 }
 
 
@@ -190,10 +215,16 @@ type HostInfo struct {
 	ConnectionState *ConnectionState
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	remoteIndexId   uint32
 	localIndexId    uint32
 	localIndexId    uint32
-	vpnIp           netip.Addr
-	recvError       atomic.Uint32
-	remoteCidr      *bart.Table[struct{}]
-	relayState      RelayState
+
+	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
+	// The host may have other vpn addresses that are outside our
+	// vpn networks but were removed because they are not usable
+	vpnAddrs  []netip.Addr
+	recvError atomic.Uint32
+
+	// networks are both all vpn and unsafe networks assigned to this host
+	networks   *bart.Table[struct{}]
+	relayState RelayState
 
 
 	// If true, we should send to this remote using multiport
 	// If true, we should send to this remote using multiport
 	multiportTx bool
 	multiportTx bool
@@ -247,28 +278,26 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 	dropped metrics.Counter
 }
 }
 
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
-	hm := newHostMap(l, vpnCIDR)
+func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
+	hm := newHostMap(l)
 
 
 	hm.reload(c, true)
 	hm.reload(c, true)
 	c.RegisterReloadCallback(func(c *config.C) {
 	c.RegisterReloadCallback(func(c *config.C) {
 		hm.reload(c, false)
 		hm.reload(c, false)
 	})
 	})
 
 
-	l.WithField("network", hm.vpnCIDR.String()).
-		WithField("preferredRanges", hm.GetPreferredRanges()).
+	l.WithField("preferredRanges", hm.GetPreferredRanges()).
 		Info("Main HostMap created")
 		Info("Main HostMap created")
 
 
 	return hm
 	return hm
 }
 }
 
 
-func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
+func newHostMap(l *logrus.Logger) *HostMap {
 	return &HostMap{
 	return &HostMap{
 		Indexes:       map[uint32]*HostInfo{},
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		Hosts:         map[netip.Addr]*HostInfo{},
 		Hosts:         map[netip.Addr]*HostInfo{},
-		vpnCIDR:       vpnCIDR,
 		l:             l,
 		l:             l,
 	}
 	}
 }
 }
@@ -311,17 +340,6 @@ func (hm *HostMap) EmitStats() {
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 }
 
 
-func (hm *HostMap) RemoveRelay(localIdx uint32) {
-	hm.Lock()
-	_, ok := hm.Relays[localIdx]
-	if !ok {
-		hm.Unlock()
-		return
-	}
-	delete(hm.Relays, localIdx)
-	hm.Unlock()
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
 	// Delete the host itself, ensuring it's not modified anymore
@@ -341,48 +359,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 }
 }
 
 
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
-	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	// Get the current primary, if it exists
+	oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
+
+	// Every address in the hostinfo gets elevated to primary
+	for _, vpnAddr := range hostinfo.vpnAddrs {
+		//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
+		// indexes so it should be fine.
+		hm.Hosts[vpnAddr] = hostinfo
+	}
+
+	// If we are already primary then we won't bother re-linking
 	if oldHostinfo == hostinfo {
 	if oldHostinfo == hostinfo {
 		return
 		return
 	}
 	}
 
 
+	// Unlink this hostinfo
 	if hostinfo.prev != nil {
 	if hostinfo.prev != nil {
 		hostinfo.prev.next = hostinfo.next
 		hostinfo.prev.next = hostinfo.next
 	}
 	}
-
 	if hostinfo.next != nil {
 	if hostinfo.next != nil {
 		hostinfo.next.prev = hostinfo.prev
 		hostinfo.next.prev = hostinfo.prev
 	}
 	}
 
 
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
+	// If there wasn't a previous primary then clear out any links
 	if oldHostinfo == nil {
 	if oldHostinfo == nil {
+		hostinfo.next = nil
+		hostinfo.prev = nil
 		return
 		return
 	}
 	}
 
 
+	// Relink the hostinfo as primary
 	hostinfo.next = oldHostinfo
 	hostinfo.next = oldHostinfo
 	oldHostinfo.prev = hostinfo
 	oldHostinfo.prev = hostinfo
 	hostinfo.prev = nil
 	hostinfo.prev = nil
 }
 }
 
 
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	for _, addr := range hostinfo.vpnAddrs {
+		h := hm.Hosts[addr]
+		for h != nil {
+			if h == hostinfo {
+				hm.unlockedInnerDeleteHostInfo(h, addr)
+			}
+			h = h.next
+		}
+	}
+}
+
+func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
+	primary, ok := hm.Hosts[addr]
+	isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
 	if ok && primary == hostinfo {
 	if ok && primary == hostinfo {
-		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
-		delete(hm.Hosts, hostinfo.vpnIp)
+		// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, addr)
 		if len(hm.Hosts) == 0 {
 		if len(hm.Hosts) == 0 {
 			hm.Hosts = map[netip.Addr]*HostInfo{}
 			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 		}
 
 
 		if hostinfo.next != nil {
 		if hostinfo.next != nil {
-			// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
-			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
+			hm.Hosts[addr] = hostinfo.next
 			// It is primary, there is no previous hostinfo now
 			// It is primary, there is no previous hostinfo now
 			hostinfo.next.prev = nil
 			hostinfo.next.prev = nil
 		}
 		}
 
 
 	} else {
 	} else {
-		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		// Relink if we were in the middle of multiple hostinfos for this vpn addr
 		if hostinfo.prev != nil {
 		if hostinfo.prev != nil {
 			hostinfo.prev.next = hostinfo.next
 			hostinfo.prev.next = hostinfo.next
 		}
 		}
@@ -412,10 +455,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 			Debug("Hostmap hostInfo deleted")
 	}
 	}
 
 
+	if isLastHostinfo {
+		// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
+		// hops as 'Requested' so that new relay tunnels are created in the future.
+		hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
+	}
+	// Clean up any local relay indexes for which I am acting as a relay hop
 	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
 	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
 		delete(hm.Relays, localRelayIdx)
 		delete(hm.Relays, localRelayIdx)
 	}
 	}
@@ -454,11 +503,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
-	return hm.queryVpnIp(vpnIp, nil)
+func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
+	return hm.queryVpnAddr(vpnIp, nil)
 }
 }
 
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 
 
@@ -466,17 +515,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
 	if !ok {
 	if !ok {
 		return nil, nil, errors.New("unable to find host")
 		return nil, nil, errors.New("unable to find host")
 	}
 	}
+
 	for h != nil {
 	for h != nil {
-		r, ok := h.relayState.QueryRelayForByIp(targetIp)
-		if ok && r.State == Established {
-			return h, r, nil
+		for _, targetIp := range targetIps {
+			r, ok := h.relayState.QueryRelayForByIp(targetIp)
+			if ok && r.State == Established {
+				return h, r, nil
+			}
 		}
 		}
 		h = h.next
 		h = h.next
 	}
 	}
+
 	return nil, nil, errors.New("unable to find host with relay")
 	return nil, nil, errors.New("unable to find host with relay")
 }
 }
 
 
-func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
+	for _, relayHostIp := range hi.relayState.CopyRelayIps() {
+		if h, ok := hm.Hosts[relayHostIp]; ok {
+			for h != nil {
+				h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+				h = h.next
+			}
+		}
+	}
+	for _, rs := range hi.relayState.CopyAllRelayFor() {
+		if rs.Type == ForwardingType {
+			if h, ok := hm.Hosts[rs.PeerAddr]; ok {
+				for h != nil {
+					h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+					h = h.next
+				}
+			}
+		}
+	}
+}
+
+func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
 		hm.RUnlock()
@@ -497,25 +571,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
 		remoteCert := hostinfo.ConnectionState.peerCert
-		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
+		dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
 	}
 	}
-
-	existing := hm.Hosts[hostinfo.vpnIp]
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
-	if existing != nil {
-		hostinfo.next = existing
-		existing.prev = hostinfo
+	for _, addr := range hostinfo.vpnAddrs {
+		hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
 	}
 	}
 
 
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
+		hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
 			Debug("Hostmap vpnIp added")
 			Debug("Hostmap vpnIp added")
 	}
 	}
+}
+
+func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
+	existing := hm.Hosts[vpnAddr]
+	hm.Hosts[vpnAddr] = hostinfo
+
+	if existing != nil && existing != hostinfo {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
 
 
 	i := 1
 	i := 1
 	check := hostinfo
 	check := hostinfo
@@ -533,7 +612,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	return *hm.preferredRanges.Load()
 	return *hm.preferredRanges.Load()
 }
 }
 
 
-func (hm *HostMap) ForEachVpnIp(f controlEach) {
+func (hm *HostMap) ForEachVpnAddr(f controlEach) {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 
 
@@ -587,11 +666,11 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
 		}
 		}
 
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp)
+		ifce.lightHouse.QueryServer(i.vpnAddrs[0])
 	}
 	}
 }
 }
 
 
-func (i *HostInfo) GetCert() *cert.NebulaCertificate {
+func (i *HostInfo) GetCert() *cert.CachedCertificate {
 	if i.ConnectionState != nil {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
 		return i.ConnectionState.peerCert
 	}
 	}
@@ -602,7 +681,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	// We copy here because we likely got this remote from a source that reuses the object
 	if i.remote != remote {
 	if i.remote != remote {
 		i.remote = remote
 		i.remote = remote
-		i.remotes.LearnRemote(i.vpnIp, remote)
+		i.remotes.LearnRemote(i.vpnAddrs[0], remote)
 	}
 	}
 }
 }
 
 
@@ -653,29 +732,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 	return true
 }
 }
 
 
-func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
-	if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
+func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+	if len(networks) == 1 && len(unsafeNetworks) == 0 {
 		// Simple case, no CIDRTree needed
 		// Simple case, no CIDRTree needed
 		return
 		return
 	}
 	}
 
 
-	remoteCidr := new(bart.Table[struct{}])
-	for _, ip := range c.Details.Ips {
-		//TODO: IPV6-WORK what to do when ip is invalid?
-		nip, _ := netip.AddrFromSlice(ip.IP)
-		nip = nip.Unmap()
-		bits, _ := ip.Mask.Size()
-		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
+	i.networks = new(bart.Table[struct{}])
+	for _, network := range networks {
+		i.networks.Insert(network, struct{}{})
 	}
 	}
 
 
-	for _, n := range c.Details.Subnets {
-		//TODO: IPV6-WORK what to do when ip is invalid?
-		nip, _ := netip.AddrFromSlice(n.IP)
-		nip = nip.Unmap()
-		bits, _ := n.Mask.Size()
-		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
+	for _, network := range unsafeNetworks {
+		i.networks.Insert(network, struct{}{})
 	}
 	}
-	i.remoteCidr = remoteCidr
 }
 }
 
 
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@@ -683,13 +753,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 		return logrus.NewEntry(l)
 	}
 	}
 
 
-	li := l.WithField("vpnIp", i.vpnIp).
+	li := l.WithField("vpnAddrs", i.vpnAddrs).
 		WithField("localIndex", i.localIndexId).
 		WithField("localIndex", i.localIndexId).
 		WithField("remoteIndex", i.remoteIndexId)
 		WithField("remoteIndex", i.remoteIndexId)
 
 
 	if connState := i.ConnectionState; connState != nil {
 	if connState := i.ConnectionState; connState != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
-			li = li.WithField("certName", peerCert.Details.Name)
+			li = li.WithField("certName", peerCert.Certificate.Name())
 		}
 		}
 	}
 	}
 
 
@@ -698,9 +768,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 
 // Utility functions
 // Utility functions
 
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
+func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
 	//FIXME: This function is pretty garbage
-	var ips []netip.Addr
+	var finalAddrs []netip.Addr
 	ifaces, _ := net.Interfaces()
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
 		allow := allowList.AllowName(i.Name)
@@ -712,39 +782,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 			continue
 			continue
 		}
 		}
 		addrs, _ := i.Addrs()
 		addrs, _ := i.Addrs()
-		for _, addr := range addrs {
-			var ip net.IP
-			switch v := addr.(type) {
+		for _, rawAddr := range addrs {
+			var addr netip.Addr
+			switch v := rawAddr.(type) {
 			case *net.IPNet:
 			case *net.IPNet:
 				//continue
 				//continue
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			case *net.IPAddr:
 			case *net.IPAddr:
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			}
 			}
 
 
-			nip, ok := netip.AddrFromSlice(ip)
-			if !ok {
+			if !addr.IsValid() {
 				if l.Level >= logrus.DebugLevel {
 				if l.Level >= logrus.DebugLevel {
-					l.WithField("localIp", ip).Debug("ip was invalid for netip")
+					l.WithField("localAddr", rawAddr).Debug("addr was invalid")
 				}
 				}
 				continue
 				continue
 			}
 			}
-			nip = nip.Unmap()
+			addr = addr.Unmap()
 
 
-			//TODO: Filtering out link local for now, this is probably the most correct thing
-			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
-				allow := allowList.Allow(nip)
+			if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
+				isAllowed := allowList.Allow(addr)
 				if l.Level >= logrus.TraceLevel {
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
 				}
 				}
-				if !allow {
+				if !isAllowed {
 					continue
 					continue
 				}
 				}
 
 
-				ips = append(ips, nip)
+				finalAddrs = append(finalAddrs, addr)
 			}
 			}
 		}
 		}
 	}
 	}
-	return ips
+	return finalAddrs
 }
 }

+ 23 - 33
hostmap_test.go

@@ -11,17 +11,14 @@ import (
 
 
 func TestHostMap_MakePrimary(t *testing.T) {
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 
 	f := &Interface{}
 	f := &Interface{}
 
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
 
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 	hm.unlockedAddHostInfo(h1, f)
 
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 	hm.MakePrimary(h3)
 
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 	hm.MakePrimary(h4)
 
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 	hm.MakePrimary(h4)
 
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
 
 
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 
 	f := &Interface{}
 	f := &Interface{}
 
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
-	h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
-	h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
+	h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
+	h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
 
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 	assert.Nil(t, h)
 
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 	assert.Nil(t, h1.next)
 
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 	assert.Nil(t, h3.next)
 
 
 	// Make sure we go h2 -> h4 -> h5
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 	assert.Nil(t, h5.next)
 
 
 	// Make sure we go h2 -> h4
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 	assert.Nil(t, h2.next)
 
 
 	// Make sure we only have h4
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
 	assert.Nil(t, prim.next)
@@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 	assert.Nil(t, h4.next)
 
 
 	// Make sure we have nil
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 	assert.Nil(t, prim)
 }
 }
 
 
@@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
 
 
-	hm := NewHostMapFromConfig(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-		c,
-	)
+	hm := NewHostMapFromConfig(l, c)
 
 
 	toS := func(ipn []netip.Prefix) []string {
 	toS := func(ipn []netip.Prefix) []string {
 		var s []string
 		var s []string

+ 2 - 2
hostmap_tester.go

@@ -9,8 +9,8 @@ import (
 	"net/netip"
 	"net/netip"
 )
 )
 
 
-func (i *HostInfo) GetVpnIp() netip.Addr {
-	return i.vpnIp
+func (i *HostInfo) GetVpnAddrs() []netip.Addr {
+	return i.vpnAddrs
 }
 }
 
 
 func (i *HostInfo) GetLocalIndex() uint32 {
 func (i *HostInfo) GetLocalIndex() uint32 {

+ 31 - 27
inside.go

@@ -21,14 +21,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 	}
 
 
 	// Ignore local broadcast packets
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
-		return
+	if f.dropLocalBroadcast {
+		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
+		if found {
+			return
+		}
 	}
 	}
 
 
-	if fwPacket.RemoteIP == f.myVpnNet.Addr() {
+	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
+	if found {
 		// Immediately forward packets from self to self.
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// This should only happen on Darwin-based and FreeBSD hosts, which
-		// routes packets from the Nebula IP to the Nebula IP through the Nebula
+		// routes packets from the Nebula addr to the Nebula addr through the Nebula
 		// TUN device.
 		// TUN device.
 		if immediatelyForwardToSelf {
 		if immediatelyForwardToSelf {
 			_, err := f.readers[q].Write(packet)
 			_, err := f.readers[q].Write(packet)
@@ -37,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 			}
 			}
 		}
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
 		// Otherwise, drop. On linux, we should never see these packets - Linux
-		// routes packets from the nebula IP to the nebula IP through the loopback device.
+		// routes packets from the nebula addr to the nebula addr through the loopback device.
 		return
 		return
 	}
 	}
 
 
 	// Ignore multicast packets
 	// Ignore multicast packets
-	if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
+	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
 		return
 		return
 	}
 	}
 
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 	})
 
 
 	if hostinfo == nil {
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", fwPacket.RemoteIP).
+			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
 				WithField("fwPacket", fwPacket).
-				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
+				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
 		}
 		return
 		return
 	}
 	}
@@ -118,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil)
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil)
 }
 }
 
 
-func (f *Interface) Handshake(vpnIp netip.Addr) {
-	f.getOrHandshake(vpnIp, nil)
+func (f *Interface) Handshake(vpnAddr netip.Addr) {
+	f.getOrHandshake(vpnAddr, nil)
 }
 }
 
 
-// getOrHandshake returns nil if the vpnIp is not routable.
+// getOrHandshake returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !f.myVpnNet.Contains(vpnIp) {
-		vpnIp = f.inside.RouteFor(vpnIp)
-		if !vpnIp.IsValid() {
+func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
+	if !found {
+		vpnAddr = f.inside.RouteFor(vpnAddr)
+		if !vpnAddr.IsValid() {
 			return nil, false
 			return nil, false
 		}
 		}
 	}
 	}
 
 
-	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
+	return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 }
 }
 
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -157,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
 }
 }
 
 
-// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
+	hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 	})
 
 
 	if hostInfo == nil {
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", vpnIp).
-				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
+			f.l.WithField("vpnAddr", vpnAddr).
+				Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
 		}
 		}
 		return
 		return
 	}
 	}
@@ -259,7 +264,6 @@ func (f *Interface) SendVia(via *HostInfo,
 
 
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
 	if ci.eKey == nil {
 	if ci.eKey == nil {
-		//TODO: log warning
 		return
 		return
 	}
 	}
 
 
@@ -303,14 +307,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	f.connectionManager.Out(hostinfo.localIndexId)
 	f.connectionManager.Out(hostinfo.localIndexId)
 
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
-	// all our IPs and enable a faster roaming.
+	// all our addrs and enable a faster roaming.
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp)
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
 		hostinfo.lastRebindCount = f.rebindCount
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 		}
 	}
 	}
 
 
@@ -354,7 +358,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 	} else {
 		// Try to send via a relay
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
 			if err != nil {
 			if err != nil {
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

+ 84 - 77
interface.go

@@ -2,7 +2,6 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
-	"encoding/binary"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -12,6 +11,7 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -28,7 +28,6 @@ type InterfaceConfig struct {
 	Outside                 udp.Conn
 	Outside                 udp.Conn
 	Inside                  overlay.Device
 	Inside                  overlay.Device
 	pki                     *PKI
 	pki                     *PKI
-	Cipher                  string
 	Firewall                *Firewall
 	Firewall                *Firewall
 	ServeDns                bool
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
 	HandshakeManager        *HandshakeManager
@@ -52,25 +51,27 @@ type InterfaceConfig struct {
 }
 }
 
 
 type Interface struct {
 type Interface struct {
-	hostMap            *HostMap
-	outside            udp.Conn
-	inside             overlay.Device
-	pki                *PKI
-	cipher             string
-	firewall           *Firewall
-	connectionManager  *connectionManager
-	handshakeManager   *HandshakeManager
-	serveDns           bool
-	createTime         time.Time
-	lightHouse         *LightHouse
-	myBroadcastAddr    netip.Addr
-	myVpnNet           netip.Prefix
-	dropLocalBroadcast bool
-	dropMulticast      bool
-	routines           int
-	disconnectInvalid  atomic.Bool
-	closed             atomic.Bool
-	relayManager       *relayManager
+	hostMap               *HostMap
+	outside               udp.Conn
+	inside                overlay.Device
+	pki                   *PKI
+	firewall              *Firewall
+	connectionManager     *connectionManager
+	handshakeManager      *HandshakeManager
+	serveDns              bool
+	createTime            time.Time
+	lightHouse            *LightHouse
+	myBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
+	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	dropLocalBroadcast    bool
+	dropMulticast         bool
+	routines              int
+	disconnectInvalid     atomic.Bool
+	closed                atomic.Bool
+	relayManager          *relayManager
 
 
 	tryPromoteEvery atomic.Uint32
 	tryPromoteEvery atomic.Uint32
 	reQueryEvery    atomic.Uint32
 	reQueryEvery    atomic.Uint32
@@ -114,9 +115,11 @@ type EncWriter interface {
 		out []byte,
 		out []byte,
 		nocopy bool,
 		nocopy bool,
 	)
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
+	SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp netip.Addr)
+	Handshake(vpnAddr netip.Addr)
+	GetHostInfo(vpnAddr netip.Addr) *HostInfo
+	GetCertState() *CertState
 }
 }
 
 
 type sendRecvErrorConfig uint8
 type sendRecvErrorConfig uint8
@@ -127,10 +130,10 @@ const (
 	sendRecvErrorPrivate
 	sendRecvErrorPrivate
 )
 )
 
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
 	switch s {
 	switch s {
 	case sendRecvErrorPrivate:
 	case sendRecvErrorPrivate:
-		return ip.Addr().IsPrivate()
+		return endpoint.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 	case sendRecvErrorAlways:
 		return true
 		return true
 	case sendRecvErrorNever:
 	case sendRecvErrorNever:
@@ -167,47 +170,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 		return nil, errors.New("no firewall rules")
 	}
 	}
 
 
-	certificate := c.pki.GetCertState().Certificate
-
-	myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
-	if !ok {
-		return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
-	}
-
-	myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
-	if !ok {
-		return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
-	}
-
-	myVpnAddr = myVpnAddr.Unmap()
-	myVpnMask = myVpnMask.Unmap()
-
-	if myVpnAddr.BitLen() != myVpnMask.BitLen() {
-		return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
-	}
-
-	ones, _ := certificate.Details.Ips[0].Mask.Size()
-	myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
-
+	cs := c.pki.getCertState()
 	ifce := &Interface{
 	ifce := &Interface{
-		pki:                c.pki,
-		hostMap:            c.HostMap,
-		outside:            c.Outside,
-		inside:             c.Inside,
-		cipher:             c.Cipher,
-		firewall:           c.Firewall,
-		serveDns:           c.ServeDns,
-		handshakeManager:   c.HandshakeManager,
-		createTime:         time.Now(),
-		lightHouse:         c.lightHouse,
-		dropLocalBroadcast: c.DropLocalBroadcast,
-		dropMulticast:      c.DropMulticast,
-		routines:           c.routines,
-		version:            c.version,
-		writers:            make([]udp.Conn, c.routines),
-		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnNet:           myVpnNet,
-		relayManager:       c.relayManager,
+		pki:                   c.pki,
+		hostMap:               c.HostMap,
+		outside:               c.Outside,
+		inside:                c.Inside,
+		firewall:              c.Firewall,
+		serveDns:              c.ServeDns,
+		handshakeManager:      c.HandshakeManager,
+		createTime:            time.Now(),
+		lightHouse:            c.lightHouse,
+		dropLocalBroadcast:    c.DropLocalBroadcast,
+		dropMulticast:         c.DropMulticast,
+		routines:              c.routines,
+		version:               c.version,
+		writers:               make([]udp.Conn, c.routines),
+		readers:               make([]io.ReadWriteCloser, c.routines),
+		myVpnNetworks:         cs.myVpnNetworks,
+		myVpnNetworksTable:    cs.myVpnNetworksTable,
+		myVpnAddrs:            cs.myVpnAddrs,
+		myVpnAddrsTable:       cs.myVpnAddrsTable,
+		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
+		relayManager:          c.relayManager,
 
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
 
@@ -221,12 +206,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 		l: c.l,
 	}
 	}
 
 
-	if myVpnAddr.Is4() {
-		addr := myVpnNet.Masked().Addr().As4()
-		binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
-		ifce.myBroadcastAddr = netip.AddrFrom4(addr)
-	}
-
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -247,7 +226,7 @@ func (f *Interface) activate() {
 		f.l.WithError(err).Error("Failed to get udp listen address")
 		f.l.WithError(err).Error("Failed to get udp listen address")
 	}
 	}
 
 
-	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
+	f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("boringcrypto", boringEnabled()).
 		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
 		Info("Nebula interface is active")
@@ -290,16 +269,22 @@ func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 	runtime.LockOSThread()
 
 
 	var li udp.Conn
 	var li udp.Conn
-	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 	if i > 0 {
 		li = f.writers[i]
 		li = f.writers[i]
 	} else {
 	} else {
 		li = f.outside
 		li = f.outside
 	}
 	}
 
 
+	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
 	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
+	plaintext := make([]byte, udp.MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
+		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	})
 }
 }
 
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -356,7 +341,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 		return
 	}
 	}
 
 
-	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
 		return
@@ -441,6 +426,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	var rawStats func()
 	var rawStats func()
 
 
 	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
 	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
+	certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
+	certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
 
 
 	for {
 	for {
 		select {
 		select {
@@ -450,17 +437,37 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
 			udpStats()
+
+			certState := f.pki.getCertState()
+			defaultCrt := certState.GetDefaultCertificate()
+			certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
+			certDefaultVersion.Update(int64(defaultCrt.Version()))
+
 			if f.udpRaw != nil {
 			if f.udpRaw != nil {
 				if rawStats == nil {
 				if rawStats == nil {
 					rawStats = udp.NewRawStatsEmitter(f.udpRaw)
 					rawStats = udp.NewRawStatsEmitter(f.udpRaw)
 				}
 				}
 				rawStats()
 				rawStats()
 			}
 			}
-			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
+
+			// Report the max certificate version we are capable of using
+			if certState.v2Cert != nil {
+				certMaxVersion.Update(int64(certState.v2Cert.Version()))
+			} else {
+				certMaxVersion.Update(int64(certState.v1Cert.Version()))
+			}
 		}
 		}
 	}
 	}
 }
 }
 
 
+func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return f.hostMap.QueryVpnAddr(vpnIp)
+}
+
+func (f *Interface) GetCertState() *CertState {
+	return f.pki.getCertState()
+}
+
 func (f *Interface) Close() error {
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 	f.closed.Store(true)
 
 

+ 0 - 2
iputil/packet.go

@@ -6,8 +6,6 @@ import (
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
 
 
-//TODO: IPV6-WORK can probably delete this
-
 const (
 const (
 	// Need 96 bytes for the largest reject packet:
 	// Need 96 bytes for the largest reject packet:
 	// - 20 byte ipv4 header
 	// - 20 byte ipv4 header

File diff suppressed because it is too large
+ 364 - 240
lighthouse.go


+ 173 - 146
lighthouse_test.go

@@ -7,6 +7,8 @@ import (
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 
 
+	"github.com/gaissmai/bart"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
@@ -14,62 +16,51 @@ import (
 	"gopkg.in/yaml.v2"
 	"gopkg.in/yaml.v2"
 )
 )
 
 
-//TODO: Add a test to ensure udpAddr is copied and not reused
-
 func TestOldIPv4Only(t *testing.T) {
 func TestOldIPv4Only(t *testing.T) {
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
-	var m Ip4AndPort
+	var m V4AddrPort
 	err := m.Unmarshal(b)
 	err := m.Unmarshal(b)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	ip := netip.MustParseAddr("10.1.1.1")
 	ip := netip.MustParseAddr("10.1.1.1")
 	bp := ip.As4()
 	bp := ip.As4()
-	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
-}
-
-func TestNewLhQuery(t *testing.T) {
-	myIp, err := netip.ParseAddr("192.1.1.1")
-	assert.NoError(t, err)
-
-	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIp)
-
-	// The result should be a nebulameta protobuf
-	assert.IsType(t, &NebulaMeta{}, a)
-
-	// It should also Marshal fine
-	b, err := a.Marshal()
-	assert.Nil(t, err)
-
-	// and then Unmarshal fine
-	n := &NebulaMeta{}
-	err = n.Unmarshal(b)
-	assert.Nil(t, err)
-
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
 }
 }
 
 
 func Test_lhStaticMapping(t *testing.T) {
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	lh2 := "10.128.0.3"
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 }
 
 
 func TestReloadLighthouseInterval(t *testing.T) {
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
@@ -79,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	}
 	}
 
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
 	lh.ifce = &mockEncWriter{}
 
 
@@ -99,9 +90,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	if !assert.NoError(b, err) {
 	if !assert.NoError(b, err) {
 		b.Fatal()
 		b.Fatal()
 	}
 	}
@@ -110,46 +107,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
 
 
 	vpnIp3 := netip.MustParseAddr("0.0.0.3")
 	vpnIp3 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp3] = NewRemoteList(nil)
+	lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
 	lh.addrMap[vpnIp3].unlockedSetV4(
 	lh.addrMap[vpnIp3].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
-			NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
+			netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
 	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
 	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
 	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
 	vpnIp2 := netip.MustParseAddr("0.0.0.3")
 	vpnIp2 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp2] = NewRemoteList(nil)
+	lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
 	lh.addrMap[vpnIp2].unlockedSetV4(
 	lh.addrMap[vpnIp2].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
-			NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
+			netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 
 
+	hi := []netip.Addr{vpnIp2}
 	b.Run("notfound", func(b *testing.B) {
 	b.Run("notfound", func(b *testing.B) {
 		lhh := lh.NewRequestHandler()
 		lhh := lh.NewRequestHandler()
 		req := &NebulaMeta{
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
 			Details: &NebulaMetaDetails{
-				VpnIp:       4,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  4,
+				V4AddrPorts: nil,
 			},
 			},
 		}
 		}
 		p, err := req.Marshal()
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		assert.NoError(b, err)
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 		}
 	})
 	})
 	b.Run("found", func(b *testing.B) {
 	b.Run("found", func(b *testing.B) {
@@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		req := &NebulaMeta{
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
 			Details: &NebulaMetaDetails{
-				VpnIp:       3,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  3,
+				V4AddrPorts: nil,
 			},
 			},
 		}
 		}
 		p, err := req.Marshal()
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		assert.NoError(b, err)
 
 
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 		}
 	})
 	})
 }
 }
@@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	lh.ifce = &mockEncWriter{}
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 	lhh := lh.NewRequestHandler()
 
 
 	// Test that my first update responds with just that
 	// Test that my first update responds with just that
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
 
 
 	// Ensure we don't accumulate addresses
 	// Ensure we don't accumulate addresses
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
 
 
 	// Grow it back to 2
 	// Grow it back to 2
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 
 	// Update a different host and ask about it
 	// Update a different host and ask about it
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 
 	// Have both hosts ask about the other
 	// Have both hosts ask about the other
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 
 	// Make sure we didn't get changed
 	// Make sure we didn't get changed
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 
 	// Ensure proper ordering and limiting
 	// Ensure proper ordering and limiting
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 	assertIp4InArray(
 		t,
 		t,
-		r.msg.Details.Ip4AndPorts,
+		r.msg.Details.V4AddrPorts,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 	)
 
 
@@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	good := netip.MustParseAddrPort("1.128.0.99:4242")
 	good := netip.MustParseAddrPort("1.128.0.99:4242")
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
 }
 }
 
 
 func TestLighthouse_reload(t *testing.T) {
 func TestLighthouse_reload(t *testing.T) {
@@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	nc := map[interface{}]interface{}{
 	nc := map[interface{}]interface{}{
@@ -290,13 +306,16 @@ func TestLighthouse_reload(t *testing.T) {
 }
 }
 
 
 func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
 func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
-	//TODO: IPV6-WORK
-	bip := queryVpnIp.As4()
 	req := &NebulaMeta{
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostQuery,
-		Details: &NebulaMetaDetails{
-			VpnIp: binary.BigEndian.Uint32(bip[:]),
-		},
+		Type:    NebulaMeta_HostQuery,
+		Details: &NebulaMetaDetails{},
+	}
+
+	if queryVpnIp.Is4() {
+		bip := queryVpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
 	}
 	}
 
 
 	b, err := req.Marshal()
 	b, err := req.Marshal()
@@ -308,23 +327,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
 	w := &testEncWriter{
 	w := &testEncWriter{
 		metaFilter: &filter,
 		metaFilter: &filter,
 	}
 	}
-	lhh.HandleRequest(fromAddr, myVpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
 	return w.lastReply
 	return w.lastReply
 }
 }
 
 
 func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
 func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
-	//TODO: IPV6-WORK
-	bip := vpnIp.As4()
 	req := &NebulaMeta{
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostUpdateNotification,
-		Details: &NebulaMetaDetails{
-			VpnIp:       binary.BigEndian.Uint32(bip[:]),
-			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
-		},
+		Type:    NebulaMeta_HostUpdateNotification,
+		Details: &NebulaMetaDetails{},
 	}
 	}
 
 
-	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
+	if vpnIp.Is4() {
+		bip := vpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
+	}
+
+	for _, v := range addrs {
+		if v.Addr().Is4() {
+			req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
+		} else {
+			req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
+		}
 	}
 	}
 
 
 	b, err := req.Marshal()
 	b, err := req.Marshal()
@@ -333,75 +358,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
 	}
 	}
 
 
 	w := &testEncWriter{}
 	w := &testEncWriter{}
-	lhh.HandleRequest(fromAddr, vpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
 }
 }
 
 
-//TODO: this is a RemoteList test
-//func Test_lhRemoteAllowList(t *testing.T) {
-//	l := NewLogger()
-//	c := NewConfig(l)
-//	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
-//		"10.20.0.0/12": false,
-//	}
-//	allowList, err := c.GetAllowList("remoteallowlist", false)
-//	assert.Nil(t, err)
-//
-//	lh1 := "10.128.0.2"
-//	lh1IP := net.ParseIP(lh1)
-//
-//	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-//
-//	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-//	lh.SetRemoteAllowList(allowList)
-//
-//	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
-//	remote1IP := net.ParseIP("10.20.0.3")
-//	remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
-//	remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
-//	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
-//	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
-//
-//	// Make sure a good ip enters the cache and addrMap
-//	remote2IP := net.ParseIP("10.128.0.3")
-//	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
-//
-//	// Another good ip gets into the cache, ordering is inverted
-//	remote3IP := net.ParseIP("10.128.0.4")
-//	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
-//
-//	// If we exceed the length limit we should only have the most recent addresses
-//	addedAddrs := []*udpAddr{}
-//	for i := 0; i < 11; i++ {
-//		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
-//		lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
-//		// The first entry here is a duplicate, don't add it to the assert list
-//		if i != 0 {
-//			addedAddrs = append(addedAddrs, remoteUDPAddr)
-//		}
-//	}
-//
-//	// We should only have the last 10 of what we tried to add
-//	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
-//	assertUdpAddrInArray(
-//		t,
-//		lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
-//		addedAddrs[0],
-//		addedAddrs[1],
-//		addedAddrs[2],
-//		addedAddrs[3],
-//		addedAddrs[4],
-//		addedAddrs[5],
-//		addedAddrs[6],
-//		addedAddrs[7],
-//		addedAddrs[8],
-//		addedAddrs[9],
-//	)
-//}
-
 type testLhReply struct {
 type testLhReply struct {
 	nebType    header.MessageType
 	nebType    header.MessageType
 	nebSubType header.MessageSubType
 	nebSubType header.MessageSubType
@@ -410,8 +369,9 @@ type testLhReply struct {
 }
 }
 
 
 type testEncWriter struct {
 type testEncWriter struct {
-	lastReply  testLhReply
-	metaFilter *NebulaMeta_MessageType
+	lastReply       testLhReply
+	metaFilter      *NebulaMeta_MessageType
+	protocolVersion cert.Version
 }
 }
 
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@@ -426,7 +386,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 		tw.lastReply = testLhReply{
 		tw.lastReply = testLhReply{
 			nebType:    t,
 			nebType:    t,
 			nebSubType: st,
 			nebSubType: st,
-			vpnIp:      hostinfo.vpnIp,
+			vpnIp:      hostinfo.vpnAddrs[0],
 			msg:        msg,
 			msg:        msg,
 		}
 		}
 	}
 	}
@@ -436,7 +396,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 	}
 }
 }
 
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -453,17 +413,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	}
 	}
 }
 }
 
 
+func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return nil
+}
+
+func (tw *testEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: tw.protocolVersion}
+}
+
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
+func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 	if !assert.Len(t, have, len(want)) {
 		return
 		return
 	}
 	}
 
 
 	for k, w := range want {
 	for k, w := range want {
-		//TODO: IPV6-WORK
-		h := AddrPortFromIp4AndPort(have[k])
+		h := protoV4AddrPortToNetAddrPort(have[k])
 		if !(h == w) {
 		if !(h == w) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 		}
 	}
 	}
 }
 }
+
+func Test_findNetworkUnion(t *testing.T) {
+	var out netip.Addr
+	var ok bool
+
+	tenDot := netip.MustParsePrefix("10.0.0.0/8")
+	oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
+	fe80 := netip.MustParsePrefix("fe80::/8")
+	fc00 := netip.MustParsePrefix("fc00::/7")
+
+	a1 := netip.MustParseAddr("10.0.0.1")
+	afe81 := netip.MustParseAddr("fe80::1")
+
+	//simple
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed lengths
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed family
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//ordering
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//some mismatches
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//falsey cases
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+}

+ 5 - 34
main.go

@@ -2,7 +2,6 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
-	"encoding/binary"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
@@ -61,25 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
 	}
 
 
-	certificate := pki.GetCertState().Certificate
-	fw, err := NewFirewallFromConfig(l, certificate, c)
+	fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
 
-	ones, _ := certificate.Details.Ips[0].Mask.Size()
-	addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
-	if !ok {
-		err = util.NewContextualError(
-			"Invalid ip address in certificate",
-			m{"vpnIp": certificate.Details.Ips[0].IP},
-			nil,
-		)
-		return nil, err
-	}
-	tunCidr := netip.PrefixFrom(addr, ones)
-
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@@ -142,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			deviceFactory = overlay.NewDeviceFromConfig
 			deviceFactory = overlay.NewDeviceFromConfig
 		}
 		}
 
 
-		tun, err = deviceFactory(c, l, tunCidr, routines)
+		tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
 		if err != nil {
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
 		}
@@ -197,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 		}
 	}
 	}
 
 
-	hostMap := NewHostMapFromConfig(l, tunCidr, c)
+	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
 	}
@@ -242,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		Inside:                  tun,
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		Outside:                 udpConns[0],
 		pki:                     pki,
 		pki:                     pki,
-		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
 		HandshakeManager:        handshakeManager,
@@ -264,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		l:                     l,
 		l:                     l,
 	}
 	}
 
 
-	switch ifConfig.Cipher {
-	case "aes":
-		noiseEndianness = binary.BigEndian
-	case "chachapoly":
-		noiseEndianness = binary.LittleEndian
-	default:
-		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
-	}
-
 	var ifce *Interface
 	var ifce *Interface
 	if !configTest {
 	if !configTest {
 		ifce, err = NewInterface(ctx, ifConfig)
 		ifce, err = NewInterface(ctx, ifConfig)
@@ -280,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			return nil, fmt.Errorf("failed to initialize interface: %s", err)
 			return nil, fmt.Errorf("failed to initialize interface: %s", err)
 		}
 		}
 
 
-		// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
-		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
 		ifce.writers = udpConns
 		lightHouse.ifce = ifce
 		lightHouse.ifce = ifce
 
 
@@ -326,8 +300,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		go handshakeManager.Run(ctx)
 		go handshakeManager.Run(ctx)
 	}
 	}
 
 
-	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
-	// a context so that they can exit when the context is Done.
 	statsStart, err := startStats(l, c, buildVersion, configTest)
 	statsStart, err := startStats(l, c, buildVersion, configTest)
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
 		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
@@ -337,7 +309,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
 
 	attachCommands(l, c, ssh, ifce)
 	attachCommands(l, c, ssh, ifce)
@@ -346,7 +317,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	var dnsStart func()
 	var dnsStart func()
 	if lightHouse.amLighthouse && serveDns {
 	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, c)
+		dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
 	}
 	}
 
 
 	return &Control{
 	return &Control{

+ 0 - 2
message_metrics.go

@@ -7,8 +7,6 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 )
 )
 
 
-//TODO: this can probably move into the header package
-
 type MessageMetrics struct {
 type MessageMetrics struct {
 	rx [][]metrics.Counter
 	rx [][]metrics.Counter
 	tx [][]metrics.Counter
 	tx [][]metrics.Counter

File diff suppressed because it is too large
+ 560 - 176
nebula.pb.go


+ 23 - 9
nebula.proto

@@ -23,19 +23,28 @@ message NebulaMeta {
 }
 }
 
 
 message NebulaMetaDetails {
 message NebulaMetaDetails {
-  uint32 VpnIp = 1;
-  repeated Ip4AndPort Ip4AndPorts = 2;
-  repeated Ip6AndPort Ip6AndPorts = 4;
-  repeated uint32 RelayVpnIp = 5;
+  uint32 OldVpnAddr = 1 [deprecated = true];
+  Addr VpnAddr = 6;
+
+  repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
+  repeated Addr RelayVpnAddrs = 7;
+
+  repeated V4AddrPort V4AddrPorts = 2;
+  repeated V6AddrPort V6AddrPorts = 4;
   uint32 counter = 3;
   uint32 counter = 3;
 }
 }
 
 
-message Ip4AndPort {
-  uint32 Ip = 1;
+message Addr {
+  uint64 Hi = 1;
+  uint64 Lo = 2;
+}
+
+message V4AddrPort {
+  uint32 Addr = 1;
   uint32 Port = 2;
   uint32 Port = 2;
 }
 }
 
 
-message Ip6AndPort {
+message V6AddrPort {
   uint64 Hi = 1;
   uint64 Hi = 1;
   uint64 Lo = 2;
   uint64 Lo = 2;
   uint32 Port = 3;
   uint32 Port = 3;
@@ -69,6 +78,7 @@ message NebulaHandshakeDetails {
   uint32 ResponderIndex = 3;
   uint32 ResponderIndex = 3;
   uint64 Cookie = 4;
   uint64 Cookie = 4;
   uint64 Time = 5;
   uint64 Time = 5;
+  uint32 CertVersion = 8;
 
 
   MultiPortDetails InitiatorMultiPort = 6;
   MultiPortDetails InitiatorMultiPort = 6;
   MultiPortDetails ResponderMultiPort = 7;
   MultiPortDetails ResponderMultiPort = 7;
@@ -84,6 +94,10 @@ message NebulaControl {
 
 
   uint32 InitiatorRelayIndex = 2;
   uint32 InitiatorRelayIndex = 2;
   uint32 ResponderRelayIndex = 3;
   uint32 ResponderRelayIndex = 3;
-  uint32 RelayToIp = 4;
-  uint32 RelayFromIp = 5;
+
+  uint32 OldRelayToAddr = 4 [deprecated = true];
+  uint32 OldRelayFromAddr = 5 [deprecated = true];
+
+  Addr RelayToAddr = 6;
+  Addr RelayFromAddr = 7;
 }
 }

+ 50 - 0
noiseutil/pkcs11.go

@@ -0,0 +1,50 @@
+package noiseutil
+
+import (
+	"crypto/ecdh"
+	"fmt"
+	"strings"
+
+	"github.com/slackhq/nebula/pkclient"
+
+	"github.com/flynn/noise"
+)
+
+// DHP256PKCS11 is the NIST P-256 ECDH function
+var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32)
+
+type nistP11Curve struct {
+	nistCurve
+}
+
+func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve {
+	return nistP11Curve{
+		newNISTCurve(name, curve, byteLen),
+	}
+}
+
+func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) {
+	//for this function "privkey" is actually a pkcs11 URI
+	pkStr := string(privkey)
+
+	//to set up a handshake, we need to also do non-pkcs11-DH. Handle that here.
+	if !strings.HasPrefix(pkStr, "pkcs11:") {
+		return DHP256.DH(privkey, pubkey)
+	}
+	ecdhPubKey, err := c.curve.NewPublicKey(pubkey)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
+	}
+
+	//this is not the most performant way to do this (a long-lived client would be better)
+	//but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users.
+	client, err := pkclient.FromUrl(pkStr)
+	if err != nil {
+		return nil, err
+	}
+	defer func(client *pkclient.PKClient) {
+		_ = client.Close()
+	}(client)
+
+	return client.DeriveNoise(ecdhPubKey.Bytes())
+}

+ 156 - 137
outside.go

@@ -3,46 +3,25 @@ package nebula
 import (
 import (
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
-	"fmt"
 	"net/netip"
 	"net/netip"
 	"time"
 	"time"
 
 
-	"github.com/flynn/noise"
+	"github.com/google/gopacket/layers"
+	"golang.org/x/net/ipv6"
+
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
-	"google.golang.org/protobuf/proto"
 )
 )
 
 
 const (
 const (
 	minFwPacketLen = 4
 	minFwPacketLen = 4
 )
 )
 
 
-// TODO: IPV6-WORK this can likely be removed now
-func readOutsidePackets(f *Interface) udp.EncReader {
-	return func(
-		addr netip.AddrPort,
-		out []byte,
-		packet []byte,
-		header *header.H,
-		fwPacket *firewall.Packet,
-		lhh udp.LightHouseHandlerFunc,
-		nb []byte,
-		q int,
-		localCache firewall.ConntrackCache,
-	) {
-		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
-	}
-}
-
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	err := h.Parse(packet)
 	if err != nil {
 	if err != nil {
-		// TODO: best if we return this and let caller log
-		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
 		if len(packet) > 1 {
 			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
 			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
@@ -52,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	if ip.IsValid() {
 	if ip.IsValid() {
-		if f.myVpnNet.Contains(ip.Addr()) {
+		_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
+		if found {
 			if f.l.Level >= logrus.DebugLevel {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
 			}
@@ -109,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			if !ok {
 			if !ok {
 				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
 				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
 				// its internal mapping. This should never happen.
 				// its internal mapping. This should never happen.
-				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
 				return
 				return
 			}
 			}
 
 
@@ -121,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				return
 				return
 			case ForwardingType:
 			case ForwardingType:
 				// Find the target HostInfo relay object
 				// Find the target HostInfo relay object
-				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
+				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
 				if err != nil {
 				if err != nil {
-					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
+					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
 					return
 					return
 				}
 				}
 
 
@@ -139,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 					}
 					}
 				} else {
 				} else {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
+					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 					return
 					return
 				}
 				}
 			}
 			}
@@ -156,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
 				Error("Failed to decrypt lighthouse packet")
-
-			//TODO: maybe after build 64 is out? 06/14/2018 - NB
-			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
 			return
 			return
 		}
 		}
 
 
-		lhf(ip, hostinfo.vpnIp, d)
+		lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
 
 
 		// Fallthrough to the bottom to record incoming traffic
 		// Fallthrough to the bottom to record incoming traffic
 
 
@@ -177,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
 				Error("Failed to decrypt test packet")
-
-			//TODO: maybe after build 64 is out? 06/14/2018 - NB
-			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
 			return
 			return
 		}
 		}
 
 
@@ -229,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				Error("Failed to decrypt Control packet")
 				Error("Failed to decrypt Control packet")
 			return
 			return
 		}
 		}
-		m := &NebulaControl{}
-		err = m.Unmarshal(d)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
-			break
-		}
 
 
-		f.relayManager.HandleControlMsg(hostinfo, m, f)
+		f.relayManager.HandleControlMsg(hostinfo, d, f)
 
 
 	default:
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -253,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	if final {
 	if final {
-		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
-		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
+		// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
+		f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
 	}
 	}
 }
 }
 
 
@@ -263,35 +231,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 }
 
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
-	if ip.IsValid() && hostinfo.remote != ip {
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
+	if udpAddr.IsValid() && hostinfo.remote != udpAddr {
 		if hostinfo.multiportRx {
 		if hostinfo.multiportRx {
 			// If the remote is sending with multiport, we aren't roaming unless
 			// If the remote is sending with multiport, we aren't roaming unless
 			// the IP has changed
 			// the IP has changed
-			if hostinfo.remote.Addr().Compare(ip.Addr()) == 0 {
+			if hostinfo.remote.Addr().Compare(udpAddr.Addr()) == 0 {
 				return
 				return
 			}
 			}
 			// Keep the port from the original hostinfo, because the remote is transmitting from multiport ports
 			// Keep the port from the original hostinfo, because the remote is transmitting from multiport ports
-			ip = netip.AddrPortFrom(ip.Addr(), hostinfo.remote.Port())
+			udpAddr = netip.AddrPortFrom(udpAddr.Addr(), hostinfo.remote.Port())
 		}
 		}
 
 
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 			return
 		}
 		}
-		if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+
+		if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			}
 			return
 			return
 		}
 		}
 
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 			Info("Host roamed to new udp ip/port.")
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(ip)
+		hostinfo.SetRemote(udpAddr)
 	}
 	}
 
 
 }
 }
@@ -311,24 +280,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
 	return true
 	return true
 }
 }
 
 
+var (
+	ErrPacketTooShort          = errors.New("packet is too short")
+	ErrUnknownIPVersion        = errors.New("packet is an unknown ip version")
+	ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
+	ErrIPv4PacketTooShort      = errors.New("ipv4 packet is too short")
+	ErrIPv6PacketTooShort      = errors.New("ipv6 packet is too short")
+	ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
+)
+
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
-	// Do we at least have an ipv4 header worth of data?
-	if len(data) < ipv4.HeaderLen {
-		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
+	if len(data) < 1 {
+		return ErrPacketTooShort
+	}
+
+	version := int((data[0] >> 4) & 0x0f)
+	switch version {
+	case ipv4.Version:
+		return parseV4(data, incoming, fp)
+	case ipv6.Version:
+		return parseV6(data, incoming, fp)
+	}
+	return ErrUnknownIPVersion
+}
+
+func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
+	dataLen := len(data)
+	if dataLen < ipv6.HeaderLen {
+		return ErrIPv6PacketTooShort
 	}
 	}
 
 
-	// Is it an ipv4 packet?
-	if int((data[0]>>4)&0x0f) != 4 {
-		return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
+	if incoming {
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
+	} else {
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
+	}
+
+	protoAt := 6             // NextHeader is at 6 bytes into the ipv6 header
+	offset := ipv6.HeaderLen // Start at the end of the ipv6 header
+	next := 0
+	for {
+		if dataLen < offset {
+			break
+		}
+
+		proto := layers.IPProtocol(data[protoAt])
+		//fmt.Println(proto, protoAt)
+		switch proto {
+		case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
+			fp.Protocol = uint8(proto)
+			fp.RemotePort = 0
+			fp.LocalPort = 0
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolTCP, layers.IPProtocolUDP:
+			if dataLen < offset+4 {
+				return ErrIPv6PacketTooShort
+			}
+
+			fp.Protocol = uint8(proto)
+			if incoming {
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			} else {
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			}
+
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolIPv6Fragment:
+			// Fragment header is 8 bytes, need at least offset+4 to read the offset field
+			if dataLen < offset+8 {
+				return ErrIPv6PacketTooShort
+			}
+
+			// Check if this is the first fragment
+			fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
+			if fragmentOffset != 0 {
+				// Non-first fragment, use what we have now and stop processing
+				fp.Protocol = data[offset]
+				fp.Fragment = true
+				fp.RemotePort = 0
+				fp.LocalPort = 0
+				return nil
+			}
+
+			// The next loop should be the transport layer since we are the first fragment
+			next = 8 // Fragment headers are always 8 bytes
+
+		case layers.IPProtocolAH:
+			// Auth headers, used by IPSec, have a different meaning for header length
+			if dataLen < offset+1 {
+				break
+			}
+
+			next = int(data[offset+1]+2) << 2
+
+		default:
+			// Normal ipv6 header length processing
+			if dataLen < offset+1 {
+				break
+			}
+
+			next = int(data[offset+1]+1) << 3
+		}
+
+		if next <= 0 {
+			// Safety check, each ipv6 header has to be at least 8 bytes
+			next = 8
+		}
+
+		protoAt = offset
+		offset = offset + next
+	}
+
+	return ErrIPv6CouldNotFindPayload
+}
+
+func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
+	// Do we at least have an ipv4 header worth of data?
+	if len(data) < ipv4.HeaderLen {
+		return ErrIPv4PacketTooShort
 	}
 	}
 
 
 	// Adjust our start position based on the advertised ip header length
 	// Adjust our start position based on the advertised ip header length
 	ihl := int(data[0]&0x0f) << 2
 	ihl := int(data[0]&0x0f) << 2
 
 
-	// Well formed ip header length?
+	// Well-formed ip header length?
 	if ihl < ipv4.HeaderLen {
 	if ihl < ipv4.HeaderLen {
-		return fmt.Errorf("packet had an invalid header length: %v", ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 	}
 
 
 	// Check if this is the second or further fragment of a fragmented packet.
 	// Check if this is the second or further fragment of a fragmented packet.
@@ -344,14 +430,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 		minLen += minFwPacketLen
 		minLen += minFwPacketLen
 	}
 	}
 	if len(data) < minLen {
 	if len(data) < minLen {
-		return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 	}
 
 
 	// Firewall packets are locally oriented
 	// Firewall packets are locally oriented
 	if incoming {
 	if incoming {
-		//TODO: IPV6-WORK
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
@@ -360,9 +445,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 		}
 	} else {
 	} else {
-		//TODO: IPV6-WORK
-		fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
@@ -397,8 +481,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
-		//TODO: maybe after build 64 is out? 06/14/2018 - NB
-		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
 		return false
 		return false
 	}
 	}
 
 
@@ -445,9 +527,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
 func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
 
-	//TODO: this should be a signed message so we can trust that we should drop the index
 	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
 	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
-	f.outside.WriteTo(b, endpoint)
+	_ = f.outside.WriteTo(b, endpoint)
 	if f.l.Level >= logrus.DebugLevel {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", index).
 		f.l.WithField("index", index).
 			WithField("udpAddr", endpoint).
 			WithField("udpAddr", endpoint).
@@ -481,65 +562,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 	// We also delete it from pending hostmap to allow for fast reconnect.
 	// We also delete it from pending hostmap to allow for fast reconnect.
 	f.handshakeManager.DeleteHostInfo(hostinfo)
 	f.handshakeManager.DeleteHostInfo(hostinfo)
 }
 }
-
-/*
-func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
-	if ci.eKey != nil {
-		//TODO: log error?
-		return
-	}
-
-	msg, err := proto.Marshal(meta)
-	if err != nil {
-		l.Debugln("failed to encode header")
-	}
-
-	c := ci.messageCounter
-	b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
-	ci.messageCounter++
-
-	msg := ci.eKey.EncryptDanger(b, nil, msg, c)
-	//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
-	f.outside.WriteTo(msg, endpoint)
-}
-*/
-
-func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
-	pk := h.PeerStatic()
-
-	if pk == nil {
-		return nil, errors.New("no peer static key was present")
-	}
-
-	if rawCertBytes == nil {
-		return nil, errors.New("provided payload was empty")
-	}
-
-	r := &cert.RawNebulaCertificate{}
-	err := proto.Unmarshal(rawCertBytes, r)
-	if err != nil {
-		return nil, fmt.Errorf("error unmarshaling cert: %s", err)
-	}
-
-	// If the Details are nil, just exit to avoid crashing
-	if r.Details == nil {
-		return nil, fmt.Errorf("certificate did not contain any details")
-	}
-
-	r.Details.PublicKey = pk
-	recombined, err := proto.Marshal(r)
-	if err != nil {
-		return nil, fmt.Errorf("error while recombining certificate: %s", err)
-	}
-
-	c, _ := cert.UnmarshalNebulaCertificate(recombined)
-	isValid, err := c.Verify(time.Now(), caPool)
-	if err != nil {
-		return c, fmt.Errorf("certificate validation failed: %s", err)
-	} else if !isValid {
-		// This case should never happen but here's to defensive programming!
-		return c, errors.New("certificate validation failed but did not return an error")
-	}
-
-	return c, nil
-}

+ 525 - 17
outside_test.go

@@ -1,10 +1,15 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"bytes"
+	"encoding/binary"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 
 
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers"
+
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
@@ -13,9 +18,15 @@ import (
 func Test_newPacket(t *testing.T) {
 func Test_newPacket(t *testing.T) {
 	p := &firewall.Packet{}
 	p := &firewall.Packet{}
 
 
-	// length fail
-	err := newPacket([]byte{0, 1}, true, p)
-	assert.EqualError(t, err, "packet is less than 20 bytes")
+	// length fails
+	err := newPacket([]byte{}, true, p)
+	assert.ErrorIs(t, err, ErrPacketTooShort)
+
+	err = newPacket([]byte{0x40}, true, p)
+	assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
+
+	err = newPacket([]byte{0x60}, true, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 
 	// length fail with ip options
 	// length fail with ip options
 	h := ipv4.Header{
 	h := ipv4.Header{
@@ -28,16 +39,15 @@ func Test_newPacket(t *testing.T) {
 
 
 	b, _ := h.Marshal()
 	b, _ := h.Marshal()
 	err = newPacket(b, true, p)
 	err = newPacket(b, true, p)
-
-	assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 
 	// not an ipv4 packet
 	// not an ipv4 packet
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet is not ipv4, type: 0")
+	assert.ErrorIs(t, err, ErrUnknownIPVersion)
 
 
 	// invalid ihl
 	// invalid ihl
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet had an invalid header length: 8")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 
 	// account for variable ip header length - incoming
 	// account for variable ip header length - incoming
 	h = ipv4.Header{
 	h = ipv4.Header{
@@ -54,11 +64,12 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, true, p)
 	err = newPacket(b, true, p)
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemotePort, uint16(3))
-	assert.Equal(t, p.LocalPort, uint16(4))
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
+	assert.Equal(t, uint16(3), p.RemotePort)
+	assert.Equal(t, uint16(4), p.LocalPort)
+	assert.False(t, p.Fragment)
 
 
 	// account for variable ip header length - outgoing
 	// account for variable ip header length - outgoing
 	h = ipv4.Header{
 	h = ipv4.Header{
@@ -75,9 +86,506 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, false, p)
 	err = newPacket(b, false, p)
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemotePort, uint16(6))
-	assert.Equal(t, p.LocalPort, uint16(5))
+	assert.Equal(t, uint8(2), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
+	assert.Equal(t, uint16(6), p.RemotePort)
+	assert.Equal(t, uint16(5), p.LocalPort)
+	assert.False(t, p.Fragment)
+}
+
+func Test_newPacket_v6(t *testing.T) {
+	p := &firewall.Packet{}
+
+	// invalid ipv6
+	ip := layers.IPv6{
+		Version:  6,
+		HopLimit: 128,
+		SrcIP:    net.IPv6linklocalallrouters,
+		DstIP:    net.IPv6linklocalallnodes,
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opt := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       false,
+	}
+	err := gopacket.SerializeLayers(buffer, opt, &ip)
+	assert.NoError(t, err)
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good ICMP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolICMPv6,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	icmp := layers.ICMPv6{}
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
+	if err != nil {
+		panic(err)
+	}
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good ESP packet
+	b := buffer.Bytes()
+	b[6] = byte(layers.IPProtocolESP)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good None packet
+	b = buffer.Bytes()
+	b[6] = byte(layers.IPProtocolNoNextHeader)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// An unknown protocol packet
+	b = buffer.Bytes()
+	b[6] = 255 // 255 is a reserved protocol number
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good UDP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: firewall.ProtoUDP,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+	err = udp.SetNetworkLayerForChecksum(&ip)
+	assert.NoError(t, err)
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
+	if err != nil {
+		panic(err)
+	}
+	b = buffer.Bytes()
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short UDP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good TCP packet
+	b[6] = byte(layers.IPProtocolTCP)
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short TCP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good UDP packet with an AH header
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolAH,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	ah := layers.IPSecAH{
+		AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
+	}
+	ah.NextHeader = layers.IPProtocolUDP
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opt)
+	if err != nil {
+		panic(err)
+	}
+
+	b = buffer.Bytes()
+	ahb := serializeAH(&ah)
+	b = append(b, ahb...)
+	b = append(b, udpHeader...)
+
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Invalid AH header
+	b = buffer.Bytes()
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+}
+
+func Test_newPacket_ipv6Fragment(t *testing.T) {
+	p := &firewall.Packet{}
+
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	// First fragment
+	fragHeader1 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: true,
+		FixLengths:       true,
+	}
+
+	err := ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader1...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test first fragment incoming
+	err = newPacket(firstFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Test first fragment outgoing
+	err = newPacket(firstFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Second fragment
+	fragHeader2 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0xb9,                        // Fragment Offset high byte (185)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader2...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test second fragment incoming
+	err = newPacket(secondFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.True(t, p.Fragment)
+
+	// Test second fragment outgoing
+	err = newPacket(secondFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.True(t, p.Fragment)
+
+	// Too short of a fragment packet
+	err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+}
+
+func BenchmarkParseV6(b *testing.B) {
+	// Regular UDP packet
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolUDP,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := &layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       true,
+	}
+
+	err := gopacket.SerializeLayers(buffer, opts, ip, udp)
+	if err != nil {
+		b.Fatal(err)
+	}
+	normalPacket := buffer.Bytes()
+
+	// First Fragment packet
+	ipFrag := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	fragHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x7b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Second Fragment packet
+	fragHeader[2] = 0xb9 // offset 185
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	fp := &firewall.Packet{}
+
+	b.Run("Normal", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(normalPacket, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("FirstFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(firstFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("SecondFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(secondFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	// Evil packet
+	evilPacket := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6HopByHop,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	hopHeader := []byte{
+		uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
+		0x00,                                 // Length
+		0x00, 0x00,                           // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	lastHopHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Length
+		0x00, 0x00,                  // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	buffer.Clear()
+	err = evilPacket.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	evilBytes := buffer.Bytes()
+	for i := 0; i < 200; i++ {
+		evilBytes = append(evilBytes, hopHeader...)
+	}
+	evilBytes = append(evilBytes, lastHopHeader...)
+	evilBytes = append(evilBytes, udpHeader...)
+	evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	b.Run("200 HopByHop headers", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(evilBytes, false, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+}
+
+// Ensure authentication data is a multiple of 8 bytes by padding if necessary
+func padAuthData(authData []byte) []byte {
+	// Length of Authentication Data must be a multiple of 8 bytes
+	paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
+	if paddingLength > 0 {
+		authData = append(authData, make([]byte, paddingLength)...)
+	}
+	return authData
+}
+
+// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
+func serializeAH(ah *layers.IPSecAH) []byte {
+	buf := new(bytes.Buffer)
+
+	// Ensure Authentication Data is a multiple of 8 bytes
+	ah.AuthenticationData = padAuthData(ah.AuthenticationData)
+	// Calculate Payload Length (in 32-bit words, minus 2)
+	payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
+
+	// Serialize fields
+	if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
+		panic(err)
+	}
+	if len(ah.AuthenticationData) > 0 {
+		if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
+			panic(err)
+		}
+	}
+
+	return buf.Bytes()
 }
 }

+ 1 - 1
overlay/device.go

@@ -8,7 +8,7 @@ import (
 type Device interface {
 type Device interface {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	Activate() error
 	Activate() error
-	Cidr() netip.Prefix
+	Networks() []netip.Prefix
 	Name() string
 	Name() string
 	RouteFor(netip.Addr) netip.Addr
 	RouteFor(netip.Addr) netip.Addr
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 	NewMultiQueueReader() (io.ReadWriteCloser, error)

+ 22 - 12
overlay/route.go

@@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
 	return routeTree, nil
 	return routeTree, nil
 }
 }
 
 
-func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.routes")
 	r := c.Get("tun.routes")
@@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 		}
 		}
 
 
-		if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
+		found := false
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
+				found = true
+				break
+			}
+		}
+
+		if !found {
 			return nil, fmt.Errorf(
 			return nil, fmt.Errorf(
-				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
+				"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
 				i+1,
 				i+1,
 				r.Cidr.String(),
 				r.Cidr.String(),
-				network.String(),
+				networks,
 			)
 			)
 		}
 		}
 
 
@@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	return routes, nil
 	return routes, nil
 }
 }
 
 
-func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.unsafe_routes")
 	r := c.Get("tun.unsafe_routes")
@@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 		}
 		}
 
 
-		if network.Contains(r.Cidr.Addr()) {
-			return nil, fmt.Errorf(
-				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
-				i+1,
-				r.Cidr.String(),
-				network.String(),
-			)
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) {
+				return nil, fmt.Errorf(
+					"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
+					i+1,
+					r.Cidr.String(),
+					network.String(),
+				)
+			}
 		}
 		}
 
 
 		routes[i] = r
 		routes[i] = r

+ 39 - 33
overlay/route_test.go

@@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	// test no routes config
 	// test no routes config
-	routes, err := parseRoutes(c, n)
+	routes, err := parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// not an array
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.routes is not an array")
 	assert.EqualError(t, err, "tun.routes is not an array")
 
 
 	// no routes
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// weird route
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 
 	// no mtu
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 
 	// bad mtu
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 
 	// low mtu
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 
 	// unparsable route
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
 
 
 	// above network range
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
+
+	// Not in multiple ranges
+	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
+	routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
+	assert.Nil(t, routes)
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
 
 
 	// happy case
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
 	}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 2)
 	assert.Len(t, routes, 2)
 
 
@@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	// test no routes config
 	// test no routes config
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// not an array
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 
 	// no routes
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// weird route
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 
 	// no via
 	// no via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
 
@@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		127, false, nil, 1.0, []string{"1", "2"},
 		127, false, nil, 1.0, []string{"1", "2"},
 	} {
 	} {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
-		routes, err = parseUnsafeRoutes(c, n)
+		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
 		assert.Nil(t, routes)
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 	}
 	}
 
 
 	// unparsable via
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 
 	// unparsable route
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 
 	// within network range
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	// above network range
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	// no mtu
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Equal(t, 0, routes[0].MTU)
 	assert.Equal(t, 0, routes[0].MTU)
 
 
 	// bad mtu
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 
 	// low mtu
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 
 	// bad install
 	// bad install
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 
 
@@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
 	}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 4)
 	assert.Len(t, routes, 4)
 
 
@@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 	}}
 	}}
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.Len(t, routes, 2)
 	assert.Len(t, routes, 2)
 	routeTree, err := makeRouteTree(l, routes, true)
 	routeTree, err := makeRouteTree(l, routes, true)

+ 9 - 9
overlay/tun.go

@@ -11,36 +11,36 @@ import (
 const DefaultMTU = 1300
 const DefaultMTU = 1300
 
 
 // TODO: We may be able to remove routines
 // TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
 
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
 	switch {
 	switch {
 	case c.GetBool("tun.disabled", false):
 	case c.GetBool("tun.disabled", false):
-		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
+		tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 		return tun, nil
 
 
 	default:
 	default:
-		return newTun(c, l, tunCidr, routines > 1)
+		return newTun(c, l, vpnNetworks, routines > 1)
 	}
 	}
 }
 }
 
 
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
-		return newTunFromFd(c, l, *fd, tunCidr)
+	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+		return newTunFromFd(c, l, *fd, vpnNetworks)
 	}
 	}
 }
 }
 
 
-func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 		return false, nil, nil
 		return false, nil, nil
 	}
 	}
 
 
-	routes, err := parseRoutes(c, cidr)
+	routes, err := parseRoutes(c, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 		return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 	}
 	}
 
 
-	unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
+	unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 		return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 	}

+ 11 - 11
overlay/tun_android.go

@@ -18,14 +18,14 @@ import (
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	fd        int
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	fd          int
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              deviceFd,
 		fd:              deviceFd,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		l:               l,
 		l:               l,
 	}
 	}
 
 
@@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 }
 
 
@@ -66,7 +66,7 @@ func (t tun) Activate() error {
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 208 - 215
overlay/tun_darwin.go

@@ -24,56 +24,62 @@ import (
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	Device     string
-	cidr       netip.Prefix
-	DefaultMTU int
-	Routes     atomic.Pointer[[]Route]
-	routeTree  atomic.Pointer[bart.Table[netip.Addr]]
-	linkAddr   *netroute.LinkAddr
-	l          *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	DefaultMTU  int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	linkAddr    *netroute.LinkAddr
+	l           *logrus.Logger
 
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	out []byte
 	out []byte
 }
 }
 
 
-type sockaddrCtl struct {
-	scLen      uint8
-	scFamily   uint8
-	ssSysaddr  uint16
-	scID       uint32
-	scUnit     uint32
-	scReserved [5]uint32
-}
-
 type ifReq struct {
 type ifReq struct {
-	Name  [16]byte
+	Name  [unix.IFNAMSIZ]byte
 	Flags uint16
 	Flags uint16
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-var sockaddrCtlSize uintptr = 32
-
 const (
 const (
-	_SYSPROTO_CONTROL = 2              //define SYSPROTO_CONTROL 2 /* kernel control protocol */
-	_AF_SYS_CONTROL   = 2              //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
-	_PF_SYSTEM        = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
-	_CTLIOCGINFO      = 3227799043     //#define CTLIOCGINFO     _IOWR('N', 3, struct ctl_info)
-	utunControlName   = "com.apple.net.utun_control"
+	_SIOCAIFADDR_IN6 = 2155899162
+	_UTUN_OPT_IFNAME = 2
+	_IN6_IFF_NODAD   = 0x0020
+	_IN6_IFF_SECURED = 0x0400
+	utunControlName  = "com.apple.net.utun_control"
 )
 )
 
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 type ifreqMTU struct {
 	Name [16]byte
 	Name [16]byte
 	MTU  int32
 	MTU  int32
 	pad  [8]byte
 	pad  [8]byte
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+type addrLifetime struct {
+	Expire    float64
+	Preferred float64
+	Vltime    uint32
+	Pltime    uint32
+}
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	name := c.GetString("tun.dev", "")
 	name := c.GetString("tun.dev", "")
 	ifIndex := -1
 	ifIndex := -1
 	if name != "" && name != "utun" {
 	if name != "" && name != "utun" {
@@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 		}
 		}
 	}
 	}
 
 
-	fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
+	fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("system socket: %v", err)
 		return nil, fmt.Errorf("system socket: %v", err)
 	}
 	}
 
 
-	var ctlInfo = &struct {
-		ctlID   uint32
-		ctlName [96]byte
-	}{}
+	var ctlInfo = &unix.CtlInfo{}
+	copy(ctlInfo.Name[:], utunControlName)
 
 
-	copy(ctlInfo.ctlName[:], utunControlName)
-
-	err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
+	err = unix.IoctlCtlInfo(fd, ctlInfo)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
 		return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
 	}
 	}
 
 
-	sc := sockaddrCtl{
-		scLen:     uint8(sockaddrCtlSize),
-		scFamily:  unix.AF_SYSTEM,
-		ssSysaddr: _AF_SYS_CONTROL,
-		scID:      ctlInfo.ctlID,
-		scUnit:    uint32(ifIndex) + 1,
-	}
-
-	_, _, errno := unix.RawSyscall(
-		unix.SYS_CONNECT,
-		uintptr(fd),
-		uintptr(unsafe.Pointer(&sc)),
-		sockaddrCtlSize,
-	)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
+	err = unix.Connect(fd, &unix.SockaddrCtl{
+		ID:   ctlInfo.Id,
+		Unit: uint32(ifIndex) + 1,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("SYS_CONNECT: %v", err)
 	}
 	}
 
 
-	var ifName struct {
-		name [16]byte
-	}
-	ifNameSize := uintptr(len(ifName.name))
-	_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
-		2, // SYSPROTO_CONTROL
-		2, // UTUN_OPT_IFNAME
-		uintptr(unsafe.Pointer(&ifName)),
-		uintptr(unsafe.Pointer(&ifNameSize)), 0)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
+	name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
+	if err != nil {
+		return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
 	}
 	}
-	name = string(ifName.name[:ifNameSize-1])
 
 
-	err = syscall.SetNonblock(fd, true)
+	err = unix.SetNonblock(fd, true)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("SetNonblock: %v", err)
 		return nil, fmt.Errorf("SetNonblock: %v", err)
 	}
 	}
 
 
-	file := os.NewFile(uintptr(fd), "")
-
 	t := &tun{
 	t := &tun{
-		ReadWriteCloser: file,
+		ReadWriteCloser: os.NewFile(uintptr(fd), ""),
 		Device:          name,
 		Device:          name,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		DefaultMTU:      c.GetInt("tun.mtu", DefaultMTU),
 		DefaultMTU:      c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 }
 
 
@@ -186,16 +167,6 @@ func (t *tun) Close() error {
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 	devName := t.deviceBytes()
 
 
-	var addr, mask [4]byte
-
-	if !t.cidr.Addr().Is4() {
-		//TODO: IPV6-WORK
-		panic("need ipv6")
-	}
-
-	addr = t.cidr.Addr().As4()
-	copy(mask[:], prefixToMask(t.cidr))
-
 	s, err := unix.Socket(
 	s, err := unix.Socket(
 		unix.AF_INET,
 		unix.AF_INET,
 		unix.SOCK_DGRAM,
 		unix.SOCK_DGRAM,
@@ -208,66 +179,18 @@ func (t *tun) Activate() error {
 
 
 	fd := uintptr(s)
 	fd := uintptr(s)
 
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
-	// Set the device name
-	ifrf := ifReq{Name: devName}
-	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to set tun device name: %s", err)
-	}
-
 	// Set the MTU on the device
 	// Set the MTU on the device
 	ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
 	ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 		return fmt.Errorf("failed to set tun mtu: %v", err)
 		return fmt.Errorf("failed to set tun mtu: %v", err)
 	}
 	}
 
 
-	/*
-		// Set the transmit queue length
-		ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
-		if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
-			// If we can't set the queue length nebula will still work but it may lead to packet loss
-			l.WithError(err).Error("Failed to set tun tx queue length")
-		}
-	*/
-
-	// Bring up the interface
-	ifrf.Flags = ifrf.Flags | unix.IFF_UP
-	if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to bring the tun device up: %s", err)
-	}
-
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	// Get the device flags
+	ifrf := ifReq{Name: devName}
+	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+		return fmt.Errorf("failed to get tun flags: %s", err)
 	}
 	}
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
 
 
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	linkAddr, err := getLinkAddr(t.Device)
 	linkAddr, err := getLinkAddr(t.Device)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -277,14 +200,18 @@ func (t *tun) Activate() error {
 	}
 	}
 	t.linkAddr = linkAddr
 	t.linkAddr = linkAddr
 
 
-	copy(routeAddr.IP[:], addr[:])
-	copy(maskAddr.IP[:], mask[:])
-	err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
-	if err != nil {
-		if errors.Is(err, unix.EEXIST) {
-			err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
+	for _, network := range t.vpnNetworks {
+		if network.Addr().Is4() {
+			err = t.activate4(network)
+			if err != nil {
+				return err
+			}
+		} else {
+			err = t.activate6(network)
+			if err != nil {
+				return err
+			}
 		}
 		}
-		return err
 	}
 	}
 
 
 	// Run the interface
 	// Run the interface
@@ -297,8 +224,89 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) activate4(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias4{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		DstAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		MaskAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(network).As4(),
+		},
+	}
+
+	if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun v4 address: %s", err)
+	}
+
+	err = addRoute(network, t.linkAddr)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (t *tun) activate6(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET6,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias6{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   network.Addr().As16(),
+		},
+		PrefixMask: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(network).As16(),
+		},
+		Lifetime: addrLifetime{
+			// never expires
+			Vltime: 0xffffffff,
+			Pltime: 0xffffffff,
+		},
+		//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
+		Flags: _IN6_IFF_NODAD,
+	}
+
+	if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun address: %s", err)
+	}
+
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 }
 }
 
 
 // Get the LinkAddr for the interface of the given name
 // Get the LinkAddr for the interface of the given name
-// TODO: Is there an easier way to fetch this when we create the interface?
+// Is there an easier way to fetch this when we create the interface?
 // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
 // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
 func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 	rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
 	rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
@@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 }
 }
 
 
 func (t *tun) addRoutes(logErrors bool) error {
 func (t *tun) addRoutes(logErrors bool) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 	for _, r := range routes {
 		if !r.Via.IsValid() || !r.Install {
 		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
 
 
-		if !r.Cidr.Addr().Is4() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		//TODO: we could avoid the copy
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := addRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 		if err != nil {
 			if errors.Is(err, unix.EEXIST) {
 			if errors.Is(err, unix.EEXIST) {
 				t.l.WithField("route", r.Cidr).
 				t.l.WithField("route", r.Cidr).
@@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
 }
 }
 
 
 func (t *tun) removeRoutes(routes []Route) error {
 func (t *tun) removeRoutes(routes []Route) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
-
 	for _, r := range routes {
 	for _, r := range routes {
 		if !r.Install {
 		if !r.Install {
 			continue
 			continue
 		}
 		}
 
 
-		if r.Cidr.Addr().Is6() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := delRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 		} else {
@@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 	return nil
 }
 }
 
 
-func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_ADD,
 		Type:    unix.RTM_ADD,
 		Flags:   unix.RTF_UP,
 		Flags:   unix.RTF_UP,
 		Seq:     1,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 	}
 
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
 	}
+
 	_, err = unix.Write(sock, data[:])
 	_, err = unix.Write(sock, data[:])
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
 		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 	return nil
 	return nil
 }
 }
 
 
-func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_DELETE,
 		Type:    unix.RTM_DELETE,
 		Seq:     1,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 	}
 
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
 	}
@@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 }
 }
 
 
 func (t *tun) Read(to []byte) (int, error) {
 func (t *tun) Read(to []byte) (int, error) {
-
 	buf := make([]byte, len(to)+4)
 	buf := make([]byte, len(to)+4)
 
 
 	n, err := t.ReadWriteCloser.Read(buf)
 	n, err := t.ReadWriteCloser.Read(buf)
@@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
 	return n - 4, err
 	return n - 4, err
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {
@@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
 }
 
 
-func prefixToMask(prefix netip.Prefix) []byte {
+func prefixToMask(prefix netip.Prefix) netip.Addr {
 	pLen := 128
 	pLen := 128
 	if prefix.Addr().Is4() {
 	if prefix.Addr().Is4() {
 		pLen = 32
 		pLen = 32
 	}
 	}
-	return net.CIDRMask(prefix.Bits(), pLen)
+
+	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+	return addr
 }
 }

+ 8 - 8
overlay/tun_disabled.go

@@ -12,8 +12,8 @@ import (
 )
 )
 
 
 type disabledTun struct {
 type disabledTun struct {
-	read chan []byte
-	cidr netip.Prefix
+	read        chan []byte
+	vpnNetworks []netip.Prefix
 
 
 	// Track these metrics since we don't have the tun device to do it for us
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
 	tx metrics.Counter
@@ -21,11 +21,11 @@ type disabledTun struct {
 	l  *logrus.Logger
 	l  *logrus.Logger
 }
 }
 
 
-func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
 	tun := &disabledTun{
-		cidr: cidr,
-		read: make(chan []byte, queueLen),
-		l:    l,
+		vpnNetworks: vpnNetworks,
+		read:        make(chan []byte, queueLen),
+		l:           l,
 	}
 	}
 
 
 	if metricsEnabled {
 	if metricsEnabled {
@@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
 	return netip.Addr{}
 	return netip.Addr{}
 }
 }
 
 
-func (t *disabledTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *disabledTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (*disabledTun) Name() string {
 func (*disabledTun) Name() string {

+ 25 - 15
overlay/tun_freebsd.go

@@ -46,12 +46,12 @@ type ifreqDestroy struct {
 }
 }
 
 
 type tun struct {
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 }
 }
@@ -78,11 +78,11 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
 	// Try to open existing tun device
 	var file *os.File
 	var file *os.File
 	var err error
 	var err error
@@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		Device:          deviceName,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 	return t, nil
 }
 }
 
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	var err error
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -195,8 +195,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 	return r
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 10 - 10
overlay/tun_ios.go

@@ -21,20 +21,20 @@ import (
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 }
 
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	t := &tun{
 	t := &tun{
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		ReadWriteCloser: &tunReadCloser{f: file},
 		ReadWriteCloser: &tunReadCloser{f: file},
 		l:               l,
 		l:               l,
 	}
 	}
@@ -59,7 +59,7 @@ func (t *tun) Activate() error {
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
 	return tr.f.Close()
 	return tr.f.Close()
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 120 - 78
overlay/tun_linux.go

@@ -11,6 +11,7 @@ import (
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
+	"time"
 	"unsafe"
 	"unsafe"
 
 
 	"github.com/gaissmai/bart"
 	"github.com/gaissmai/bart"
@@ -25,7 +26,7 @@ type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	fd          int
 	fd          int
 	Device      string
 	Device      string
-	cidr        netip.Prefix
+	vpnNetworks []netip.Prefix
 	MaxMTU      int
 	MaxMTU      int
 	DefaultMTU  int
 	DefaultMTU  int
 	TXQueueLen  int
 	TXQueueLen  int
@@ -40,18 +41,16 @@ type tun struct {
 	l *logrus.Logger
 	l *logrus.Logger
 }
 }
 
 
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
+
 type ifReq struct {
 type ifReq struct {
 	Name  [16]byte
 	Name  [16]byte
 	Flags uint16
 	Flags uint16
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 type ifreqMTU struct {
 	Name [16]byte
 	Name [16]byte
 	MTU  int32
 	MTU  int32
@@ -64,10 +63,10 @@ type ifreqQLEN struct {
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 	if err != nil {
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	name := strings.Trim(string(req.Name[:]), "\x00")
 	name := strings.Trim(string(req.Name[:]), "\x00")
 
 
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
 		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
 		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
 		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
 		l:               l,
 		l:               l,
@@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
 		}
 		}
 
 
 		if oldDefaultMTU != newDefaultMTU {
 		if oldDefaultMTU != newDefaultMTU {
-			err := t.setDefaultRoute()
-			if err != nil {
-				t.l.Warn(err)
-			} else {
-				t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+			for i := range t.vpnNetworks {
+				err := t.setDefaultRoute(t.vpnNetworks[i])
+				if err != nil {
+					t.l.Warn(err)
+				} else {
+					t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+				}
 			}
 			}
 		}
 		}
 
 
@@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 
 
 func (t *tun) Write(b []byte) (int, error) {
 func (t *tun) Write(b []byte) (int, error) {
 	var nn int
 	var nn int
-	max := len(b)
+	maximum := len(b)
 
 
 	for {
 	for {
-		n, err := unix.Write(t.fd, b[nn:max])
+		n, err := unix.Write(t.fd, b[nn:maximum])
 		if n > 0 {
 		if n > 0 {
 			nn += n
 			nn += n
 		}
 		}
@@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
+func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
+	for i := range al {
+		if al[i].Equal(x) {
+			return true
+		}
+	}
+	return false
+}
+
+// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
+func (t *tun) addIPs(link netlink.Link) error {
+	newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
+	for i := range t.vpnNetworks {
+		newAddrs[i] = &netlink.Addr{
+			IPNet: &net.IPNet{
+				IP:   t.vpnNetworks[i].Addr().AsSlice(),
+				Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
+			},
+			Label: t.vpnNetworks[i].Addr().Zone(),
+		}
+	}
+
+	//add all new addresses
+	for i := range newAddrs {
+		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
+		//AddrReplace still adds new IPs, but if their properties change it will change them as well
+		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
+			return err
+		}
+	}
+
+	//iterate over remainder, remove whoever shouldn't be there
+	al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
+	if err != nil {
+		return fmt.Errorf("failed to get tun address list: %s", err)
+	}
+
+	for i := range al {
+		if hasNetlinkAddr(newAddrs, al[i]) {
+			continue
+		}
+		err = netlink.AddrDel(link, &al[i])
+		if err != nil {
+			t.l.WithError(err).Error("failed to remove address from tun address list")
+		} else {
+			t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
+		}
+	}
+
+	return nil
+}
+
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 	devName := t.deviceBytes()
 
 
@@ -272,15 +325,8 @@ func (t *tun) Activate() error {
 		t.watchRoutes()
 		t.watchRoutes()
 	}
 	}
 
 
-	var addr, mask [4]byte
-
-	//TODO: IPV6-WORK
-	addr = t.cidr.Addr().As4()
-	tmask := net.CIDRMask(t.cidr.Bits(), 32)
-	copy(mask[:], tmask)
-
 	s, err := unix.Socket(
 	s, err := unix.Socket(
-		unix.AF_INET,
+		unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
 		unix.SOCK_DGRAM,
 		unix.SOCK_DGRAM,
 		unix.IPPROTO_IP,
 		unix.IPPROTO_IP,
 	)
 	)
@@ -289,31 +335,19 @@ func (t *tun) Activate() error {
 	}
 	}
 	t.ioctlFd = uintptr(s)
 	t.ioctlFd = uintptr(s)
 
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
 	// Set the device name
 	// Set the device name
 	ifrf := ifReq{Name: devName}
 	ifrf := ifReq{Name: devName}
 	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to set tun device name: %s", err)
 		return fmt.Errorf("failed to set tun device name: %s", err)
 	}
 	}
 
 
+	link, err := netlink.LinkByName(t.Device)
+	if err != nil {
+		return fmt.Errorf("failed to get tun device link: %s", err)
+	}
+
+	t.deviceIndex = link.Attrs().Index
+
 	// Setup our default MTU
 	// Setup our default MTU
 	t.setMTU()
 	t.setMTU()
 
 
@@ -324,20 +358,21 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 	}
 
 
+	if err = t.addIPs(link); err != nil {
+		return err
+	}
+
 	// Bring up the interface
 	// Bring up the interface
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP
 	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to bring the tun device up: %s", err)
 		return fmt.Errorf("failed to bring the tun device up: %s", err)
 	}
 	}
 
 
-	link, err := netlink.LinkByName(t.Device)
-	if err != nil {
-		return fmt.Errorf("failed to get tun device link: %s", err)
-	}
-	t.deviceIndex = link.Attrs().Index
-
-	if err = t.setDefaultRoute(); err != nil {
-		return err
+	//set route MTU
+	for i := range t.vpnNetworks {
+		if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
+			return fmt.Errorf("failed to set default route MTU: %w", err)
+		}
 	}
 	}
 
 
 	// Set the routes
 	// Set the routes
@@ -363,12 +398,10 @@ func (t *tun) setMTU() {
 	}
 	}
 }
 }
 
 
-func (t *tun) setDefaultRoute() error {
-	// Default route
-
+func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
 	dr := &net.IPNet{
 	dr := &net.IPNet{
-		IP:   t.cidr.Masked().Addr().AsSlice(),
-		Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+		IP:   cidr.Masked().Addr().AsSlice(),
+		Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
 	}
 	}
 
 
 	nr := netlink.Route{
 	nr := netlink.Route{
@@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error {
 		MTU:       t.DefaultMTU,
 		MTU:       t.DefaultMTU,
 		AdvMSS:    t.advMSS(Route{}),
 		AdvMSS:    t.advMSS(Route{}),
 		Scope:     unix.RT_SCOPE_LINK,
 		Scope:     unix.RT_SCOPE_LINK,
-		Src:       net.IP(t.cidr.Addr().AsSlice()),
+		Src:       net.IP(cidr.Addr().AsSlice()),
 		Protocol:  unix.RTPROT_KERNEL,
 		Protocol:  unix.RTPROT_KERNEL,
 		Table:     unix.RT_TABLE_MAIN,
 		Table:     unix.RT_TABLE_MAIN,
 		Type:      unix.RTN_UNICAST,
 		Type:      unix.RTN_UNICAST,
 	}
 	}
 	err := netlink.RouteReplace(&nr)
 	err := netlink.RouteReplace(&nr)
 	if err != nil {
 	if err != nil {
-		return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
+		t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
+		//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
+		for i := 0; i < 2; i++ {
+			time.Sleep(100 * time.Millisecond)
+			err = netlink.RouteReplace(&nr)
+			if err == nil {
+				break
+			} else {
+				t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
+			}
+		}
+		if err != nil {
+			return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
+		}
 	}
 	}
 
 
 	return nil
 	return nil
@@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) {
 	}
 	}
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
 func (t *tun) Name() string {
 func (t *tun) Name() string {
 	return t.Device
 	return t.Device
 }
 }
@@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 		return
 	}
 	}
 
 
-	//TODO: IPV6-WORK what if not ok?
 	gwAddr, ok := netip.AddrFromSlice(r.Gw)
 	gwAddr, ok := netip.AddrFromSlice(r.Gw)
 	if !ok {
 	if !ok {
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
@@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	}
 	}
 
 
 	gwAddr = gwAddr.Unmap()
 	gwAddr = gwAddr.Unmap()
-	if !t.cidr.Contains(gwAddr) {
-		// Gateway isn't in our overlay network, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
-		return
+	withinNetworks := false
+	for i := range t.vpnNetworks {
+		if t.vpnNetworks[i].Contains(gwAddr) {
+			withinNetworks = true
+			break
+		}
 	}
 	}
-
-	if x := r.Dst.IP.To4(); x == nil {
-		// Nebula only handles ipv4 on the overlay currently
-		t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
+	if !withinNetworks {
+		// Gateway isn't in our overlay network, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
 		return
 		return
 	}
 	}
 
 
@@ -563,11 +605,11 @@ func (t *tun) Close() error {
 	}
 	}
 
 
 	if t.ReadWriteCloser != nil {
 	if t.ReadWriteCloser != nil {
-		t.ReadWriteCloser.Close()
+		_ = t.ReadWriteCloser.Close()
 	}
 	}
 
 
 	if t.ioctlFd > 0 {
 	if t.ioctlFd > 0 {
-		os.NewFile(t.ioctlFd, "ioctlFd").Close()
+		_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
 	}
 	}
 
 
 	return nil
 	return nil

+ 28 - 17
overlay/tun_netbsd.go

@@ -27,12 +27,12 @@ type ifreqDestroy struct {
 }
 }
 
 
 type tun struct {
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 }
 }
@@ -58,13 +58,13 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 }
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
 	// Try to open tun device
 	var file *os.File
 	var file *os.File
 	var err error
 	var err error
@@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		Device:          deviceName,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 	return t, nil
 }
 }
 
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	var err error
 
 
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -130,8 +130,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 	return r
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {
@@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

+ 29 - 19
overlay/tun_openbsd.go

@@ -21,12 +21,12 @@ import (
 )
 )
 
 
 type tun struct {
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 
 
@@ -42,13 +42,13 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 }
 }
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
 	if deviceName == "" {
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		Device:          deviceName,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	var err error
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -138,7 +138,7 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -148,6 +148,16 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
@@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 		if !r.Install {
 		if !r.Install {
 			continue
 			continue
 		}
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 17 - 17
overlay/tun_tester.go

@@ -16,19 +16,19 @@ import (
 )
 )
 
 
 type TestTun struct {
 type TestTun struct {
-	Device    string
-	cidr      netip.Prefix
-	Routes    []Route
-	routeTree *bart.Table[netip.Addr]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	Routes      []Route
+	routeTree   *bart.Table[netip.Addr]
+	l           *logrus.Logger
 
 
 	closed    atomic.Bool
 	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
-	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
+	_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
 	}
 	}
 
 
 	return &TestTun{
 	return &TestTun{
-		Device:    c.GetString("tun.dev", ""),
-		cidr:      cidr,
-		Routes:    routes,
-		routeTree: routeTree,
-		l:         l,
-		rxPackets: make(chan []byte, 10),
-		TxPackets: make(chan []byte, 10),
+		Device:      c.GetString("tun.dev", ""),
+		vpnNetworks: vpnNetworks,
+		Routes:      routes,
+		routeTree:   routeTree,
+		l:           l,
+		rxPackets:   make(chan []byte, 10),
+		TxPackets:   make(chan []byte, 10),
 	}, nil
 	}, nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 }
 
 
@@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *TestTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *TestTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *TestTun) Name() string {
 func (t *TestTun) Name() string {

Some files were not shown because too many files changed in this diff