瀏覽代碼

Merge remote-tracking branch 'origin/master' into mutex-debug

Wade Simmons 3 月之前
父節點
當前提交
1c9fdba403
共有 100 個文件被更改,包括 9305 次插入6644 次删除
  1. 1 1
      .github/workflows/gofmt.yml
  2. 5 5
      .github/workflows/release.yml
  3. 3 0
      .github/workflows/smoke-extra.yml
  4. 1 1
      .github/workflows/smoke.yml
  5. 7 3
      .github/workflows/smoke/build.sh
  6. 18 18
      .github/workflows/smoke/smoke-vagrant.sh
  7. 33 5
      .github/workflows/test.yml
  8. 3 1
      .gitignore
  9. 23 0
      .golangci.yaml
  10. 50 1
      CHANGELOG.md
  11. 14 5
      Makefile
  12. 2 2
      README.md
  13. 50 90
      allow_list.go
  14. 46 45
      allow_list_test.go
  15. 55 30
      calculated_remote.go
  16. 64 10
      calculated_remote_test.go
  17. 1 1
      cert/Makefile
  18. 15 4
      cert/README.md
  19. 52 0
      cert/asn1.go
  20. 0 140
      cert/ca.go
  21. 296 0
      cert/ca_pool.go
  22. 560 0
      cert/ca_pool_test.go
  23. 98 976
      cert/cert.go
  24. 0 1230
      cert/cert_test.go
  25. 489 0
      cert/cert_v1.go
  26. 111 111
      cert/cert_v1.pb.go
  27. 0 0
      cert/cert_v1.proto
  28. 218 0
      cert/cert_v1_test.go
  29. 37 0
      cert/cert_v2.asn1
  30. 730 0
      cert/cert_v2.go
  31. 267 0
      cert/cert_v2_test.go
  32. 159 2
      cert/crypto.go
  33. 90 2
      cert/crypto_test.go
  34. 41 6
      cert/errors.go
  35. 141 0
      cert/helper_test.go
  36. 161 0
      cert/pem.go
  37. 293 0
      cert/pem_test.go
  38. 167 0
      cert/sign.go
  39. 91 0
      cert/sign_test.go
  40. 138 0
      cert_test/cert.go
  41. 0 10
      cidr/parse.go
  42. 0 203
      cidr/tree4.go
  43. 0 170
      cidr/tree4_test.go
  44. 0 189
      cidr/tree6.go
  45. 0 98
      cidr/tree6_test.go
  46. 137 73
      cmd/nebula-cert/ca.go
  47. 75 68
      cmd/nebula-cert/ca_test.go
  48. 49 20
      cmd/nebula-cert/keygen.go
  49. 26 24
      cmd/nebula-cert/keygen_test.go
  50. 1 1
      cmd/nebula-cert/main.go
  51. 12 4
      cmd/nebula-cert/main_test.go
  52. 15 0
      cmd/nebula-cert/p11_cgo.go
  53. 16 0
      cmd/nebula-cert/p11_stub.go
  54. 14 9
      cmd/nebula-cert/print.go
  55. 158 34
      cmd/nebula-cert/print_test.go
  56. 238 110
      cmd/nebula-cert/sign.go
  57. 109 109
      cmd/nebula-cert/sign_test.go
  58. 26 15
      cmd/nebula-cert/verify.go
  59. 35 52
      cmd/nebula-cert/verify_test.go
  60. 33 17
      config/config.go
  61. 31 34
      config/config_test.go
  62. 67 38
      connection_manager.go
  63. 178 76
      connection_manager_test.go
  64. 35 23
      connection_state.go
  65. 78 43
      control.go
  66. 40 45
      control_test.go
  67. 51 37
      control_tester.go
  68. 86 40
      dns_server.go
  69. 28 13
      dns_server_test.go
  70. 394 173
      e2e/handshakes_test.go
  71. 0 118
      e2e/helpers.go
  72. 92 51
      e2e/helpers_test.go
  73. 9 8
      e2e/router/hostmap.go
  74. 72 77
      e2e/router/router.go
  75. 44 13
      examples/config.yml
  76. 18 9
      examples/go_service/main.go
  77. 128 110
      firewall.go
  78. 13 13
      firewall/packet.go
  79. 259 317
      firewall_test.go
  80. 25 22
      go.mod
  81. 48 41
      go.sum
  82. 303 117
      handshake_ix.go
  83. 215 149
      handshake_manager.go
  84. 25 18
      handshake_manager_test.go
  85. 1 1
      header/header.go
  86. 2 1
      header/header_test.go
  87. 209 138
      hostmap.go
  88. 27 46
      hostmap_test.go
  89. 5 3
      hostmap_tester.go
  90. 117 44
      inside.go
  91. 84 53
      interface.go
  92. 0 93
      iputil/util.go
  93. 0 17
      iputil/util_test.go
  94. 382 287
      lighthouse.go
  95. 256 247
      lighthouse_test.go
  96. 19 34
      main.go
  97. 0 2
      message_metrics.go
  98. 0 18
      metadata.go
  99. 467 171
      nebula.pb.go
  100. 23 9
      nebula.proto

+ 1 - 1
.github/workflows/gofmt.yml

@@ -18,7 +18,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.24'
         check-latest: true
 
     - name: Install goimports

+ 5 - 5
.github/workflows/release.yml

@@ -14,7 +14,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version: '1.24'
           check-latest: true
 
       - name: Build
@@ -37,7 +37,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version: '1.24'
           check-latest: true
 
       - name: Build
@@ -64,18 +64,18 @@ jobs:
     name: Build Universal Darwin
     env:
       HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }}
-    runs-on: macos-11
+    runs-on: macos-latest
     steps:
       - uses: actions/checkout@v4
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version: '1.24'
           check-latest: true
 
       - name: Import certificates
         if: env.HAS_SIGNING_CREDS == 'true'
-        uses: Apple-Actions/import-codesign-certs@v2
+        uses: Apple-Actions/import-codesign-certs@v5
         with:
           p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }}
           p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }}

+ 3 - 0
.github/workflows/smoke-extra.yml

@@ -27,6 +27,9 @@ jobs:
         go-version-file: 'go.mod'
         check-latest: true
 
+    - name: add hashicorp source
+      run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
+
     - name: install vagrant
       run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
 

+ 1 - 1
.github/workflows/smoke.yml

@@ -22,7 +22,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.24'
         check-latest: true
 
     - name: build

+ 7 - 3
.github/workflows/smoke/build.sh

@@ -5,6 +5,10 @@ set -e -x
 rm -rf ./build
 mkdir ./build
 
+# TODO: Assumes your docker bridge network is a /24, and the first container that launches will be .1
+# - We could make this better by launching the lighthouse first and then fetching what IP it is.
+NET="$(docker network inspect bridge -f '{{ range .IPAM.Config }}{{ .Subnet }}{{ end }}' | cut -d. -f1-3)"
+
 (
     cd build
 
@@ -21,16 +25,16 @@ mkdir ./build
         ../genconfig.sh >lighthouse1.yml
 
     HOST="host2" \
-        LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
+        LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
         ../genconfig.sh >host2.yml
 
     HOST="host3" \
-        LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
+        LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
         INBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
         ../genconfig.sh >host3.yml
 
     HOST="host4" \
-        LIGHTHOUSES="192.168.100.1 172.17.0.2:4242" \
+        LIGHTHOUSES="192.168.100.1 $NET.2:4242" \
         OUTBOUND='[{"port": "any", "proto": "icmp", "group": "lighthouse"}]' \
         ../genconfig.sh >host4.yml
 

+ 18 - 18
.github/workflows/smoke/smoke-vagrant.sh

@@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
 docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
 
 vagrant up
-vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
+vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
 
 docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
-vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
+vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 15
 
 # grab tcpdump pcaps for debugging
@@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
 # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
 # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
 
-docker exec host2 ncat -nklv 0.0.0.0 2000 &
-vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
+#docker exec host2 ncat -nklv 0.0.0.0 2000 &
+#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
 #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
 #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
 
@@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
 # Should fail because not allowed by host3 inbound firewall
 ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 
-set +x
-echo
-echo " *** Testing ncat from host2"
-echo
-set -x
+#set +x
+#echo
+#echo " *** Testing ncat from host2"
+#echo
+#set -x
 # Should fail because not allowed by host3 inbound firewall
 #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
@@ -82,18 +82,18 @@ echo
 echo " *** Testing ping from host3"
 echo
 set -x
-vagrant ssh -c "ping -c1 192.168.100.1"
-vagrant ssh -c "ping -c1 192.168.100.2"
-
-set +x
-echo
-echo " *** Testing ncat from host3"
-echo
-set -x
+vagrant ssh -c "ping -c1 192.168.100.1" -- -T
+vagrant ssh -c "ping -c1 192.168.100.2" -- -T
+
+#set +x
+#echo
+#echo " *** Testing ncat from host3"
+#echo
+#set -x
 #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
 #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
 
-vagrant ssh -c "sudo xargs kill </nebula/pid"
+vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
 docker exec host2 sh -c 'kill 1'
 docker exec lighthouse1 sh -c 'kill 1'
 sleep 1

+ 33 - 5
.github/workflows/test.yml

@@ -22,7 +22,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.24'
         check-latest: true
 
     - name: Build
@@ -31,6 +31,11 @@ jobs:
     - name: Vet
       run: make vet
 
+    - name: golangci-lint
+      uses: golangci/golangci-lint-action@v7
+      with:
+        version: v2.0
+
     - name: Test
       run: make test
 
@@ -55,7 +60,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.24'
         check-latest: true
 
     - name: Build
@@ -65,21 +70,39 @@ jobs:
       run: make test-boringcrypto
 
     - name: End 2 end
-      run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+      run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
+
+  test-linux-pkcs11:
+    name: Build and test on linux with pkcs11
+    runs-on: ubuntu-latest
+    steps:
+
+    - uses: actions/checkout@v4
+
+    - uses: actions/setup-go@v5
+      with:
+        go-version: '1.22'
+        check-latest: true
+
+    - name: Build
+      run: make bin-pkcs11
+
+    - name: Test
+      run: make test-pkcs11
 
   test:
     name: Build and test on ${{ matrix.os }}
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [windows-latest, macos-11]
+        os: [windows-latest, macos-latest]
     steps:
 
     - uses: actions/checkout@v4
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.24'
         check-latest: true
 
     - name: Build nebula
@@ -91,6 +114,11 @@ jobs:
     - name: Vet
       run: make vet
 
+    - name: golangci-lint
+      uses: golangci/golangci-lint-action@v7
+      with:
+        version: v2.0
+
     - name: Test
       run: make test
 

+ 3 - 1
.gitignore

@@ -5,7 +5,8 @@
 /nebula-darwin
 /nebula.exe
 /nebula-cert.exe
-/coverage.out
+**/coverage.out
+**/cover.out
 /cpu.pprof
 /build
 /*.tar.gz
@@ -13,5 +14,6 @@
 **.crt
 **.key
 **.pem
+**.pub
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt

+ 23 - 0
.golangci.yaml

@@ -0,0 +1,23 @@
+version: "2"
+linters:
+  default: none
+  enable:
+    - testifylint
+  exclusions:
+    generated: lax
+    presets:
+      - comments
+      - common-false-positives
+      - legacy
+      - std-error-handling
+    paths:
+      - third_party$
+      - builtin$
+      - examples$
+formatters:
+  exclusions:
+    generated: lax
+    paths:
+      - third_party$
+      - builtin$
+      - examples$

+ 50 - 1
CHANGELOG.md

@@ -7,6 +7,51 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Changed
+
+- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
+  intended to target an `unsafe_routes` entry must explicitly declare it via the
+  `local_cidr` field. This is almost always the intended behavior. This flag is
+  deprecated and will be removed in a future release.
+
+## [1.9.4] - 2024-09-09
+
+### Added
+
+- Support UDP dialing with gVisor. (#1181)
+
+### Changed
+
+- Make some Nebula state programmatically available via control object. (#1188)
+- Switch internal representation of IPs to netip, to prepare for IPv6 support
+  in the overlay. (#1173)
+- Minor build and cleanup changes. (#1171, #1164, #1162)
+- Various dependency updates. (#1195, #1190, #1174, #1168, #1167, #1161, #1147, #1146)
+
+### Fixed
+
+- Fix a bug on big endian hosts, like mips. (#1194)
+- Fix a rare panic if a local index collision happens. (#1191)
+- Fix integer wraparound in the calculation of handshake timeouts on 32-bit targets. (#1185)
+
+## [1.9.3] - 2024-06-06
+
+### Fixed
+
+- Initialize messageCounter to 2 instead of verifying later. (#1156)
+
+## [1.9.2] - 2024-06-03
+
+### Fixed
+
+- Ensure messageCounter is set before handshake is complete. (#1154)
+
+## [1.9.1] - 2024-05-29
+
+### Fixed
+
+- Fixed a potential deadlock in GetOrHandshake. (#1151)
+
 ## [1.9.0] - 2024-05-07
 
 ### Deprecated
@@ -626,7 +671,11 @@ created.)
 
 - Initial public release.
 
-[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.0...HEAD
+[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
+[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
+[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
+[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2
+[1.9.1]: https://github.com/slackhq/nebula/releases/tag/v1.9.1
 [1.9.0]: https://github.com/slackhq/nebula/releases/tag/v1.9.0
 [1.8.2]: https://github.com/slackhq/nebula/releases/tag/v1.8.2
 [1.8.1]: https://github.com/slackhq/nebula/releases/tag/v1.8.1

+ 14 - 5
Makefile

@@ -40,7 +40,7 @@ ALL_LINUX = linux-amd64 \
 	linux-mips64le \
 	linux-mips-softfloat \
 	linux-riscv64 \
-        linux-loong64
+	linux-loong64
 
 ALL_FREEBSD = freebsd-amd64 \
 	freebsd-arm64
@@ -66,7 +66,7 @@ e2e:
 e2e-mutex-debug:
 	$(TEST_ENV) go test -tags=mutex_debug,e2e_testing -count=1 $(TEST_FLAGS) ./e2e
 
-e2ev: TEST_FLAGS = -v
+e2ev: TEST_FLAGS += -v
 e2ev: e2e
 
 e2evv: TEST_ENV += TEST_LOGS=1
@@ -99,7 +99,7 @@ release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)
 
 release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
 
-BUILD_ARGS = -trimpath
+BUILD_ARGS += -trimpath
 
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
 	mv $? .
@@ -119,6 +119,10 @@ bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
 bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
 	mv $? .
 
+bin-pkcs11: BUILD_ARGS += -tags pkcs11
+bin-pkcs11: CGO_ENABLED = 1
+bin-pkcs11: bin
+
 bin:
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
@@ -136,6 +140,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
 # boringcrypto
 build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
 build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
+build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
 
 build/%/nebula: .FORCE
 	GOOS=$(firstword $(subst -, , $*)) \
@@ -169,7 +175,10 @@ test:
 	go test -v ./...
 
 test-boringcrypto:
-	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
+	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
+
+test-pkcs11:
+	CGO_ENABLED=1 go test -v -tags pkcs11 ./...
 
 test-cov-html:
 	go test -coverprofile=coverage.out
@@ -192,7 +201,7 @@ bench-cpu-long:
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 
-proto: nebula.pb.go cert/cert.pb.go
+proto: nebula.pb.go cert/cert_v1.pb.go
 
 nebula.pb.go: nebula.proto .FORCE
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster

+ 2 - 2
README.md

@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
 
 You can read more about Nebula [here](https://medium.com/p/884110a5579).
 
-You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
+You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
 
 ## Supported Platforms
 
@@ -47,7 +47,7 @@ Check the [releases](https://github.com/slackhq/nebula/releases/latest) page for
     $ sudo apk add nebula
     ```
 
-- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/nebula.rb)
+- [macOS Homebrew](https://github.com/Homebrew/homebrew-core/blob/HEAD/Formula/n/nebula.rb)
     ```
     $ brew install nebula
     ```

+ 50 - 90
allow_list.go

@@ -2,17 +2,16 @@ package nebula
 
 import (
 	"fmt"
-	"net"
+	"net/netip"
 	"regexp"
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *cidr.Tree6[bool]
+	cidrTree *bart.Table[bool]
 }
 
 type RemoteAllowList struct {
@@ -20,7 +19,7 @@ type RemoteAllowList struct {
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
-	insideAllowLists *cidr.Tree6[*AllowList]
+	insideAllowLists *bart.Table[*AllowList]
 }
 
 type LocalAllowList struct {
@@ -37,7 +36,7 @@ type AllowListNameRule struct {
 
 func NewLocalAllowListFromConfig(c *config.C, k string) (*LocalAllowList, error) {
 	var nameRules []AllowListNameRule
-	handleKey := func(key string, value interface{}) (bool, error) {
+	handleKey := func(key string, value any) (bool, error) {
 		if key == "interfaces" {
 			var err error
 			nameRules, err = getAllowListInterfaces(k, value)
@@ -71,7 +70,7 @@ func NewRemoteAllowListFromConfig(c *config.C, k, rangesKey string) (*RemoteAllo
 
 // If the handleKey func returns true, the rest of the parsing is skipped
 // for this key. This allows parsing of special values like `interfaces`.
-func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
+func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
 	r := c.Get(k)
 	if r == nil {
 		return nil, nil
@@ -82,13 +81,13 @@ func newAllowListFromConfig(c *config.C, k string, handleKey func(key string, va
 
 // If the handleKey func returns true, the rest of the parsing is skipped
 // for this key. This allows parsing of special values like `interfaces`.
-func newAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
-	rawMap, ok := raw.(map[interface{}]interface{})
+func newAllowList(k string, raw any, handleKey func(key string, value any) (bool, error)) (*AllowList, error) {
+	rawMap, ok := raw.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
 
-	tree := cidr.NewTree6[bool]()
+	tree := new(bart.Table[bool])
 
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	type allowListRules struct {
@@ -101,12 +100,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 	rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
 	rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
 
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
+	for rawCIDR, rawValue := range rawMap {
 		if handleKey != nil {
 			handled, err := handleKey(rawCIDR, rawValue)
 			if err != nil {
@@ -117,23 +111,24 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 			}
 		}
 
-		value, ok := rawValue.(bool)
+		value, ok := config.AsBool(rawValue)
 		if !ok {
 			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
 		}
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		ipNet, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
 		}
 
-		// TODO: should we error on duplicate CIDRs in the config?
-		tree.AddCIDR(ipNet, value)
+		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
 
-		maskBits, maskSize := ipNet.Mask.Size()
+		tree.Insert(ipNet, value)
+
+		maskBits := ipNet.Bits()
 
 		var rules *allowListRules
-		if maskSize == 32 {
+		if ipNet.Addr().Is4() {
 			rules = &rules4
 		} else {
 			rules = &rules6
@@ -156,8 +151,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 	if !rules4.defaultSet {
 		if rules4.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
-			tree.AddCIDR(zeroCIDR, !rules4.allValues)
+			tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
 		} else {
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
 		}
@@ -165,8 +159,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 	if !rules6.defaultSet {
 		if rules6.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("::/0")
-			tree.AddCIDR(zeroCIDR, !rules6.allValues)
+			tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
 		} else {
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
 		}
@@ -175,22 +168,18 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 	return &AllowList{cidrTree: tree}, nil
 }
 
-func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
+func getAllowListInterfaces(k string, v any) ([]AllowListNameRule, error) {
 	var nameRules []AllowListNameRule
 
-	rawRules, ok := v.(map[interface{}]interface{})
+	rawRules, ok := v.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
 	}
 
 	firstEntry := true
 	var allValues bool
-	for rawName, rawAllow := range rawRules {
-		name, ok := rawName.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
-		}
-		allow, ok := rawAllow.(bool)
+	for name, rawAllow := range rawRules {
+		allow, ok := config.AsBool(rawAllow)
 		if !ok {
 			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
 		}
@@ -218,72 +207,49 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
 	return nameRules, nil
 }
 
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
+func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	remoteAllowRanges := cidr.NewTree6[*AllowList]()
+	remoteAllowRanges := new(bart.Table[*AllowList])
 
-	rawMap, ok := value.(map[interface{}]interface{})
+	rawMap, ok := value.(map[string]any)
 	if !ok {
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
 	}
-	for rawKey, rawValue := range rawMap {
-		rawCIDR, ok := rawKey.(string)
-		if !ok {
-			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
-		}
-
+	for rawCIDR, rawValue := range rawMap {
 		allowList, err := newAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
 		if err != nil {
 			return nil, err
 		}
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		ipNet, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
 		}
 
-		remoteAllowRanges.AddCIDR(ipNet, allowList)
+		remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
 	}
 
 	return remoteAllowRanges, nil
 }
 
-func (al *AllowList) Allow(ip net.IP) bool {
+func (al *AllowList) Allow(addr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
 
-	_, result := al.cidrTree.MostSpecificContains(ip)
+	result, _ := al.cidrTree.Lookup(addr)
 	return result
 }
 
-func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
+func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-
-	_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
-	return result
-}
-
-func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
-	if al == nil {
-		return true
-	}
-
-	_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
-	return result
-}
-
-func (al *LocalAllowList) Allow(ip net.IP) bool {
-	if al == nil {
-		return true
-	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
 func (al *LocalAllowList) AllowName(name string) bool {
@@ -301,43 +267,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 }
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
+func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(vpnAddr)
 }
 
-func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
-	if !al.getInsideAllowList(vpnIp).Allow(ip) {
+func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
+	if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
 		return false
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
-func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
-	if al == nil {
-		return true
-	}
-	if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
+func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
+	if !al.AllowList.Allow(udpAddr) {
 		return false
 	}
-	return al.AllowList.AllowIpV4(ip)
-}
 
-func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
-	if al == nil {
-		return true
-	}
-	if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
-		return false
+	for _, vpnAddr := range vpnAddrs {
+		if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
+			return false
+		}
 	}
-	return al.AllowList.AllowIpV6(hi, lo)
+
+	return true
 }
 
-func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
-		ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnAddr)
 		if ok {
 			return inside
 		}

+ 46 - 45
allow_list_test.go

@@ -1,40 +1,41 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"regexp"
 	"testing"
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewAllowListFromConfig(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"192.168.0.0": true,
 	}
 	r, err := newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+	require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
 	assert.Nil(t, r)
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"192.168.0.0/16": "abc",
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
+	require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"192.168.0.0/16": true,
 		"10.0.0.0/8":     false,
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
+	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"0.0.0.0/0":      true,
 		"10.0.0.0/8":     false,
 		"10.42.42.0/24":  true,
@@ -42,9 +43,9 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		"fd00:fd00::/16": false,
 	}
 	r, err = newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
+	require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"0.0.0.0/0":     true,
 		"10.0.0.0/8":    false,
 		"10.42.42.0/24": true,
@@ -54,7 +55,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		assert.NotNil(t, r)
 	}
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
 		"0.0.0.0/0":      true,
 		"10.0.0.0/8":     false,
 		"10.42.42.0/24":  true,
@@ -69,25 +70,25 @@ func TestNewAllowListFromConfig(t *testing.T) {
 
 	// Test interface names
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
+		"interfaces": map[string]any{
 			`docker.*`: "foo",
 		},
 	}
 	lr, err := NewLocalAllowListFromConfig(c, "allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
+	require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
+		"interfaces": map[string]any{
 			`docker.*`: false,
 			`eth.*`:    true,
 		},
 	}
 	lr, err = NewLocalAllowListFromConfig(c, "allowlist")
-	assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
+	require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
 
-	c.Settings["allowlist"] = map[interface{}]interface{}{
-		"interfaces": map[interface{}]interface{}{
+	c.Settings["allowlist"] = map[string]any{
+		"interfaces": map[string]any{
 			`docker.*`: false,
 		},
 	}
@@ -98,30 +99,30 @@ func TestNewAllowListFromConfig(t *testing.T) {
 }
 
 func TestAllowList_Allow(t *testing.T) {
-	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
-
-	tree := cidr.NewTree6[bool]()
-	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
-	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
-	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
-	tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
-	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
-	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
-	tree.AddCIDR(cidr.Parse("::1/128"), true)
-	tree.AddCIDR(cidr.Parse("::2/128"), false)
+	assert.True(t, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
+
+	tree := new(bart.Table[bool])
+	tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
+	tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
+	tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
+	tree.Insert(netip.MustParsePrefix("::1/128"), true)
+	tree.Insert(netip.MustParsePrefix("::2/128"), false)
 	al := &AllowList{cidrTree: tree}
 
-	assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
+	assert.True(t, al.Allow(netip.MustParseAddr("1.1.1.1")))
+	assert.False(t, al.Allow(netip.MustParseAddr("10.0.0.4")))
+	assert.True(t, al.Allow(netip.MustParseAddr("10.42.42.42")))
+	assert.False(t, al.Allow(netip.MustParseAddr("10.42.42.41")))
+	assert.True(t, al.Allow(netip.MustParseAddr("10.42.0.1")))
+	assert.True(t, al.Allow(netip.MustParseAddr("::1")))
+	assert.False(t, al.Allow(netip.MustParseAddr("::2")))
 }
 
 func TestLocalAllowList_AllowName(t *testing.T) {
-	assert.Equal(t, true, ((*LocalAllowList)(nil)).AllowName("docker0"))
+	assert.True(t, ((*LocalAllowList)(nil)).AllowName("docker0"))
 
 	rules := []AllowListNameRule{
 		{Name: regexp.MustCompile("^docker.*$"), Allow: false},
@@ -129,9 +130,9 @@ func TestLocalAllowList_AllowName(t *testing.T) {
 	}
 	al := &LocalAllowList{nameRules: rules}
 
-	assert.Equal(t, false, al.AllowName("docker0"))
-	assert.Equal(t, false, al.AllowName("tun0"))
-	assert.Equal(t, true, al.AllowName("eth0"))
+	assert.False(t, al.AllowName("docker0"))
+	assert.False(t, al.AllowName("tun0"))
+	assert.True(t, al.AllowName("eth0"))
 
 	rules = []AllowListNameRule{
 		{Name: regexp.MustCompile("^eth.*$"), Allow: true},
@@ -139,7 +140,7 @@ func TestLocalAllowList_AllowName(t *testing.T) {
 	}
 	al = &LocalAllowList{nameRules: rules}
 
-	assert.Equal(t, false, al.AllowName("docker0"))
-	assert.Equal(t, true, al.AllowName("eth0"))
-	assert.Equal(t, true, al.AllowName("ens5"))
+	assert.False(t, al.AllowName("docker0"))
+	assert.True(t, al.AllowName("eth0"))
+	assert.True(t, al.AllowName("ens5"))
 }

+ 55 - 30
calculated_remote.go

@@ -1,41 +1,40 @@
 package nebula
 
 import (
+	"encoding/binary"
 	"fmt"
 	"math"
 	"net"
+	"net/netip"
 	"strconv"
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 // This allows us to "guess" what the remote might be for a host while we wait
 // for the lighthouse response. See "lighthouse.calculated_remotes" in the
 // example config file.
 type calculatedRemote struct {
-	ipNet  net.IPNet
-	maskIP iputil.VpnIp
-	mask   iputil.VpnIp
-	port   uint32
+	ipNet netip.Prefix
+	mask  netip.Prefix
+	port  uint32
 }
 
-func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
-	// Ensure this is an IPv4 mask that we expect
-	ones, bits := ipNet.Mask.Size()
-	if ones == 0 || bits != 32 {
-		return nil, fmt.Errorf("invalid mask: %v", ipNet)
+func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
+		return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
 	}
+
+	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
 	}
 
 	return &calculatedRemote{
-		ipNet:  *ipNet,
-		maskIP: iputil.Ip2VpnIp(ipNet.IP),
-		mask:   iputil.Ip2VpnIp(ipNet.Mask),
-		port:   uint32(port),
+		ipNet: maskCidr,
+		mask:  masked,
+		port:  uint32(port),
 	}, nil
 }
 
@@ -43,21 +42,47 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 
-func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
-	// Combine the masked bytes of the "mask" IP with the unmasked bytes
-	// of the overlay IP
-	masked := (c.maskIP & c.mask) | (ip & ^c.mask)
+func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
+	// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
+	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	mask := binary.BigEndian.Uint32(maskb[:])
+
+	b := c.mask.Addr().As4()
+	maskAddr := binary.BigEndian.Uint32(b[:])
+
+	b = addr.As4()
+	intAddr := binary.BigEndian.Uint32(b[:])
+
+	return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
+}
+
+func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
+	mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	maskAddr := c.mask.Addr().As16()
+	calcAddr := addr.As16()
+
+	ap := V6AddrPort{Port: c.port}
+
+	maskb := binary.BigEndian.Uint64(mask[:8])
+	maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
+	calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
+	ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	maskb = binary.BigEndian.Uint64(mask[8:])
+	maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
+	calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
+	ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
 
-	return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
+	return &ap
 }
 
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
+	calculatedRemotes := new(bart.Table[[]*calculatedRemote])
 
 	rawMap, ok := value.(map[any]any)
 	if !ok {
@@ -69,23 +94,23 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
 		}
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		cidr, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 
-		entry, err := newCalculatedRemotesListFromConfig(rawValue)
+		entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
 
-		calculatedRemotes.AddCIDR(ipNet, entry)
+		calculatedRemotes.Insert(cidr, entry)
 	}
 
 	return calculatedRemotes, nil
 }
 
-func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
+func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
 	rawList, ok := raw.([]any)
 	if !ok {
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@@ -93,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 
 	var l []*calculatedRemote
 	for _, e := range rawList {
-		c, err := newCalculatedRemotesEntryFromConfig(e)
+		c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
 		if err != nil {
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 		}
@@ -103,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 	return l, nil
 }
 
-func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
+func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
 	rawMap, ok := raw.(map[any]any)
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
@@ -117,7 +142,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 	if !ok {
 		return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
 	}
-	_, ipNet, err := net.ParseCIDR(rawMask)
+	maskCidr, err := netip.ParsePrefix(rawMask)
 	if err != nil {
 		return nil, fmt.Errorf("invalid mask: %s", rawMask)
 	}
@@ -139,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 
-	return newCalculatedRemote(ipNet, port)
+	return newCalculatedRemote(cidr, maskCidr, port)
 }

+ 64 - 10
calculated_remote_test.go

@@ -1,27 +1,81 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"testing"
 
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
 func TestCalculatedRemoteApply(t *testing.T) {
-	_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
+	// Test v4 addresses
+	ipNet := netip.MustParsePrefix("192.168.1.0/24")
+	c, err := newCalculatedRemote(ipNet, ipNet, 4242)
 	require.NoError(t, err)
 
-	c, err := newCalculatedRemote(ipNet, 4242)
+	input, err := netip.ParseAddr("10.0.10.182")
 	require.NoError(t, err)
 
-	input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
+	expected, err := netip.ParseAddr("192.168.1.182")
+	require.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
+
+	// Test v6 addresses
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	require.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
+	require.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	require.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
+	require.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	require.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
+	require.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+}
+
+func Test_newCalculatedRemote(t *testing.T) {
+	c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
+	require.Nil(t, c)
 
-	expected := &Ip4AndPort{
-		Ip:   uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
-		Port: 4242,
-	}
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
+	require.Nil(t, c)
 
-	assert.Equal(t, expected, c.Apply(input))
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
 }

+ 1 - 1
cert/Makefile

@@ -1,7 +1,7 @@
 GO111MODULE = on
 export GO111MODULE
 
-cert.pb.go: cert.proto .FORCE
+cert_v1.pb.go: cert_v1.proto .FORCE
 	go build google.golang.org/protobuf/cmd/protoc-gen-go
 	PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
 	rm protoc-gen-go

+ 15 - 4
cert/README.md

@@ -2,14 +2,25 @@
 
 This is a library for interacting with `nebula` style certificates and authorities.
 
-A `protobuf` definition of the certificate format is also included
+There are now 2 versions of `nebula` certificates:
 
-### Compiling the protobuf definition
+## v1
 
-Make sure you have `protoc` installed.
+This version is deprecated.
+
+A `protobuf` definition of the certificate format is included at `cert_v1.proto`
+
+To compile the definition you will need `protoc` installed.
 
 To compile for `go` with the same version of protobuf specified in go.mod:
 
 ```bash
-make
+make proto
 ```
+
+## v2
+
+This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
+future certificate changes better than v1.
+
+`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

+ 52 - 0
cert/asn1.go

@@ -0,0 +1,52 @@
+package cert
+
+import (
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+)
+
+// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
+// https://github.com/golang/go/issues/64811#issuecomment-1944446920
+func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0] > 0
+		return true
+	}
+
+	return false
+}
+
+// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
+// Similar issue as with readOptionalASN1Boolean
+func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0]
+		return true
+	}
+
+	return false
+}

+ 0 - 140
cert/ca.go

@@ -1,140 +0,0 @@
-package cert
-
-import (
-	"errors"
-	"fmt"
-	"strings"
-	"time"
-)
-
-type NebulaCAPool struct {
-	CAs           map[string]*NebulaCertificate
-	certBlocklist map[string]struct{}
-}
-
-// NewCAPool creates a CAPool
-func NewCAPool() *NebulaCAPool {
-	ca := NebulaCAPool{
-		CAs:           make(map[string]*NebulaCertificate),
-		certBlocklist: make(map[string]struct{}),
-	}
-
-	return &ca
-}
-
-// NewCAPoolFromBytes will create a new CA pool from the provided
-// input bytes, which must be a PEM-encoded set of nebula certificates.
-// If the pool contains any expired certificates, an ErrExpired will be
-// returned along with the pool. The caller must handle any such errors.
-func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
-	pool := NewCAPool()
-	var err error
-	var expired bool
-	for {
-		caPEMs, err = pool.AddCACertificate(caPEMs)
-		if errors.Is(err, ErrExpired) {
-			expired = true
-			err = nil
-		}
-		if err != nil {
-			return nil, err
-		}
-		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
-			break
-		}
-	}
-
-	if expired {
-		return pool, ErrExpired
-	}
-
-	return pool, nil
-}
-
-// AddCACertificate verifies a Nebula CA certificate and adds it to the pool
-// Only the first pem encoded object will be consumed, any remaining bytes are returned.
-// Parsed certificates will be verified and must be a CA
-func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
-	c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
-	if err != nil {
-		return pemBytes, err
-	}
-
-	if !c.Details.IsCA {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
-	}
-
-	if !c.CheckSignature(c.Details.PublicKey) {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
-	}
-
-	sum, err := c.Sha256Sum()
-	if err != nil {
-		return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
-	}
-
-	ncp.CAs[sum] = c
-	if c.Expired(time.Now()) {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
-	}
-
-	return pemBytes, nil
-}
-
-// BlocklistFingerprint adds a cert fingerprint to the blocklist
-func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
-	ncp.certBlocklist[f] = struct{}{}
-}
-
-// ResetCertBlocklist removes all previously blocklisted cert fingerprints
-func (ncp *NebulaCAPool) ResetCertBlocklist() {
-	ncp.certBlocklist = make(map[string]struct{})
-}
-
-// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated
-// automatically if you manually change any fields in the NebulaCertificate.
-func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
-	return ncp.isBlocklistedWithCache(c, false)
-}
-
-// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
-func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool {
-	h, err := c.sha256SumWithCache(useCache)
-	if err != nil {
-		return true
-	}
-
-	if _, ok := ncp.certBlocklist[h]; ok {
-		return true
-	}
-
-	return false
-}
-
-// GetCAForCert attempts to return the signing certificate for the provided certificate.
-// No signature validation is performed
-func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
-	if c.Details.Issuer == "" {
-		return nil, fmt.Errorf("no issuer in certificate")
-	}
-
-	signer, ok := ncp.CAs[c.Details.Issuer]
-	if ok {
-		return signer, nil
-	}
-
-	return nil, fmt.Errorf("could not find ca for the certificate")
-}
-
-// GetFingerprints returns an array of trusted CA fingerprints
-func (ncp *NebulaCAPool) GetFingerprints() []string {
-	fp := make([]string, len(ncp.CAs))
-
-	i := 0
-	for k := range ncp.CAs {
-		fp[i] = k
-		i++
-	}
-
-	return fp
-}

+ 296 - 0
cert/ca_pool.go

@@ -0,0 +1,296 @@
+package cert
+
+import (
+	"errors"
+	"fmt"
+	"net/netip"
+	"slices"
+	"strings"
+	"time"
+)
+
+type CAPool struct {
+	CAs           map[string]*CachedCertificate
+	certBlocklist map[string]struct{}
+}
+
+// NewCAPool creates an empty CAPool
+func NewCAPool() *CAPool {
+	ca := CAPool{
+		CAs:           make(map[string]*CachedCertificate),
+		certBlocklist: make(map[string]struct{}),
+	}
+
+	return &ca
+}
+
+// NewCAPoolFromPEM will create a new CA pool from the provided
+// input bytes, which must be a PEM-encoded set of nebula certificates.
+// If the pool contains any expired certificates, an ErrExpired will be
+// returned along with the pool. The caller must handle any such errors.
+func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
+	pool := NewCAPool()
+	var err error
+	var expired bool
+	for {
+		caPEMs, err = pool.AddCAFromPEM(caPEMs)
+		if errors.Is(err, ErrExpired) {
+			expired = true
+			err = nil
+		}
+		if err != nil {
+			return nil, err
+		}
+		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
+			break
+		}
+	}
+
+	if expired {
+		return pool, ErrExpired
+	}
+
+	return pool, nil
+}
+
+// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
+// Only the first pem encoded object will be consumed, any remaining bytes are returned.
+// Parsed certificates will be verified and must be a CA
+func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
+	c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
+	if err != nil {
+		return pemBytes, err
+	}
+
+	err = ncp.AddCA(c)
+	if err != nil {
+		return pemBytes, err
+	}
+
+	return pemBytes, nil
+}
+
+// AddCA verifies a Nebula CA certificate and adds it to the pool.
+func (ncp *CAPool) AddCA(c Certificate) error {
+	if !c.IsCA() {
+		return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
+	}
+
+	if !c.CheckSignature(c.PublicKey()) {
+		return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
+	}
+
+	sum, err := c.Fingerprint()
+	if err != nil {
+		return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
+	}
+
+	cc := &CachedCertificate{
+		Certificate:    c,
+		Fingerprint:    sum,
+		InvertedGroups: make(map[string]struct{}),
+	}
+
+	for _, g := range c.Groups() {
+		cc.InvertedGroups[g] = struct{}{}
+	}
+
+	ncp.CAs[sum] = cc
+
+	if c.Expired(time.Now()) {
+		return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
+	}
+
+	return nil
+}
+
+// BlocklistFingerprint adds a cert fingerprint to the blocklist
+func (ncp *CAPool) BlocklistFingerprint(f string) {
+	ncp.certBlocklist[f] = struct{}{}
+}
+
+// ResetCertBlocklist removes all previously blocklisted cert fingerprints
+func (ncp *CAPool) ResetCertBlocklist() {
+	ncp.certBlocklist = make(map[string]struct{})
+}
+
+// IsBlocklisted tests the provided fingerprint against the pools blocklist.
+// Returns true if the fingerprint is blocked.
+func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
+	if _, ok := ncp.certBlocklist[fingerprint]; ok {
+		return true
+	}
+
+	return false
+}
+
+// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
+// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
+// to increase performance.
+func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
+	if c == nil {
+		return nil, fmt.Errorf("no certificate")
+	}
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
+	}
+
+	signer, err := ncp.verify(c, now, fp, "")
+	if err != nil {
+		return nil, err
+	}
+
+	cc := CachedCertificate{
+		Certificate:       c,
+		InvertedGroups:    make(map[string]struct{}),
+		Fingerprint:       fp,
+		signerFingerprint: signer.Fingerprint,
+	}
+
+	for _, g := range c.Groups() {
+		cc.InvertedGroups[g] = struct{}{}
+	}
+
+	return &cc, nil
+}
+
+// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
+// is a cheaper operation to perform as a result.
+func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
+	_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
+	return err
+}
+
+func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
+	if ncp.IsBlocklisted(certFp) {
+		return nil, ErrBlockListed
+	}
+
+	signer, err := ncp.GetCAForCert(c)
+	if err != nil {
+		return nil, err
+	}
+
+	if signer.Certificate.Expired(now) {
+		return nil, ErrRootExpired
+	}
+
+	if c.Expired(now) {
+		return nil, ErrExpired
+	}
+
+	// If we are checking a cached certificate then we can bail early here
+	// Either the root is no longer trusted or everything is fine
+	if len(signerFp) > 0 {
+		if signerFp != signer.Fingerprint {
+			return nil, ErrFingerprintMismatch
+		}
+		return signer, nil
+	}
+	if !c.CheckSignature(signer.Certificate.PublicKey()) {
+		return nil, ErrSignatureMismatch
+	}
+
+	err = CheckCAConstraints(signer.Certificate, c)
+	if err != nil {
+		return nil, err
+	}
+
+	return signer, nil
+}
+
+// GetCAForCert attempts to return the signing certificate for the provided certificate.
+// No signature validation is performed
+func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
+	issuer := c.Issuer()
+	if issuer == "" {
+		return nil, fmt.Errorf("no issuer in certificate")
+	}
+
+	signer, ok := ncp.CAs[issuer]
+	if ok {
+		return signer, nil
+	}
+
+	return nil, ErrCaNotFound
+}
+
+// GetFingerprints returns an array of trusted CA fingerprints
+func (ncp *CAPool) GetFingerprints() []string {
+	fp := make([]string, len(ncp.CAs))
+
+	i := 0
+	for k := range ncp.CAs {
+		fp[i] = k
+		i++
+	}
+
+	return fp
+}
+
+// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
+func CheckCAConstraints(signer Certificate, sub Certificate) error {
+	return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
+}
+
+// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
+func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
+	// Make sure this cert isn't valid after the root
+	if notAfter.After(signer.NotAfter()) {
+		return fmt.Errorf("certificate expires after signing certificate")
+	}
+
+	// Make sure this cert wasn't valid before the root
+	if notBefore.Before(signer.NotBefore()) {
+		return fmt.Errorf("certificate is valid before the signing certificate")
+	}
+
+	// If the signer has a limited set of groups make sure the cert only contains a subset
+	signerGroups := signer.Groups()
+	if len(signerGroups) > 0 {
+		for _, g := range groups {
+			if !slices.Contains(signerGroups, g) {
+				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
+			}
+		}
+	}
+
+	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
+	signingNetworks := signer.Networks()
+	if len(signingNetworks) > 0 {
+		for _, certNetwork := range networks {
+			found := false
+			for _, signingNetwork := range signingNetworks {
+				if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
+			}
+		}
+	}
+
+	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
+	signingUnsafeNetworks := signer.UnsafeNetworks()
+	if len(signingUnsafeNetworks) > 0 {
+		for _, certUnsafeNetwork := range unsafeNetworks {
+			found := false
+			for _, caNetwork := range signingUnsafeNetworks {
+				if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
+			}
+		}
+	}
+
+	return nil
+}

+ 560 - 0
cert/ca_pool_test.go

@@ -0,0 +1,560 @@
+package cert
+
+import (
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNewCAPoolFromBytes(t *testing.T) {
+	noNewLines := `
+# Current provisional, Remove once everything moves over to the real root.
+-----BEGIN NEBULA CERTIFICATE-----
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
+-----END NEBULA CERTIFICATE-----
+# root-ca01
+-----BEGIN NEBULA CERTIFICATE-----
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
+-----END NEBULA CERTIFICATE-----
+`
+
+	withNewLines := `
+# Current provisional, Remove once everything moves over to the real root.
+
+-----BEGIN NEBULA CERTIFICATE-----
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
+-----END NEBULA CERTIFICATE-----
+
+# root-ca01
+
+
+-----BEGIN NEBULA CERTIFICATE-----
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
+-----END NEBULA CERTIFICATE-----
+
+`
+
+	expired := `
+# expired certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
+7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
+Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
+-----END NEBULA CERTIFICATE-----
+`
+
+	p256 := `
+# p256 certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
+k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
+75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
+-----END NEBULA CERTIFICATE-----
+`
+
+	rootCA := certificateV1{
+		details: detailsV1{
+			name: "nebula root ca",
+		},
+	}
+
+	rootCA01 := certificateV1{
+		details: detailsV1{
+			name: "nebula root ca 01",
+		},
+	}
+
+	rootCAP256 := certificateV1{
+		details: detailsV1{
+			name: "nebula P256 test",
+		},
+	}
+
+	p, err := NewCAPoolFromPEM([]byte(noNewLines))
+	require.NoError(t, err)
+	assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+
+	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
+	require.NoError(t, err)
+	assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+
+	// expired cert, no valid certs
+	ppp, err := NewCAPoolFromPEM([]byte(expired))
+	assert.Equal(t, ErrExpired, err)
+	assert.Equal(t, "expired", ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
+
+	// expired cert, with valid certs
+	pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
+	assert.Equal(t, ErrExpired, err)
+	assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+	assert.Equal(t, "expired", pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name())
+	assert.Len(t, pppp.CAs, 3)
+
+	ppppp, err := NewCAPoolFromPEM([]byte(p256))
+	require.NoError(t, err)
+	assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
+	assert.Len(t, ppppp.CAs, 1)
+}
+
+func TestCertificateV1_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	require.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	require.NoError(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	require.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV1_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	require.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	require.NoError(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	require.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV1_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV1_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV2_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	require.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	require.NoError(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	require.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV2_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	require.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	require.NoError(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	require.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV2_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}
+
+func TestCertificateV2_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	require.NoError(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	require.NoError(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	require.NoError(t, err)
+}

+ 98 - 976
cert/cert.go

@@ -1,1029 +1,151 @@
 package cert
 
 import (
-	"bytes"
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/ed25519"
-	"crypto/elliptic"
-	"crypto/rand"
-	"crypto/sha256"
-	"encoding/binary"
-	"encoding/hex"
-	"encoding/json"
-	"encoding/pem"
-	"errors"
 	"fmt"
-	"math"
-	"math/big"
-	"net"
-	"sync/atomic"
+	"net/netip"
 	"time"
-
-	"golang.org/x/crypto/curve25519"
-	"google.golang.org/protobuf/proto"
 )
 
-const publicKeyLen = 32
+type Version uint8
 
 const (
-	CertBanner                       = "NEBULA CERTIFICATE"
-	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
-	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
-	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
-	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
-	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
-
-	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
-	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
-	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
-	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+	VersionPre1 Version = 0
+	Version1    Version = 1
+	Version2    Version = 2
 )
 
-type NebulaCertificate struct {
-	Details   NebulaCertificateDetails
-	Signature []byte
-
-	// the cached hex string of the calculated sha256sum
-	// for VerifyWithCache
-	sha256sum atomic.Pointer[string]
-
-	// the cached public key bytes if they were verified as the signer
-	// for VerifyWithCache
-	signatureVerified atomic.Pointer[[]byte]
-}
-
-type NebulaCertificateDetails struct {
-	Name      string
-	Ips       []*net.IPNet
-	Subnets   []*net.IPNet
-	Groups    []string
-	NotBefore time.Time
-	NotAfter  time.Time
-	PublicKey []byte
-	IsCA      bool
-	Issuer    string
-
-	// Map of groups for faster lookup
-	InvertedGroups map[string]struct{}
-
-	Curve Curve
-}
-
-type NebulaEncryptedData struct {
-	EncryptionMetadata NebulaEncryptionMetadata
-	Ciphertext         []byte
-}
-
-type NebulaEncryptionMetadata struct {
-	EncryptionAlgorithm string
-	Argon2Parameters    Argon2Parameters
-}
-
-type m map[string]interface{}
-
-// Returned if we try to unmarshal an encrypted private key without a passphrase
-var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
-
-// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert
-func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
-	if len(b) == 0 {
-		return nil, fmt.Errorf("nil byte array")
-	}
-	var rc RawNebulaCertificate
-	err := proto.Unmarshal(b, &rc)
-	if err != nil {
-		return nil, err
-	}
-
-	if rc.Details == nil {
-		return nil, fmt.Errorf("encoded Details was nil")
-	}
-
-	if len(rc.Details.Ips)%2 != 0 {
-		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
-	}
-
-	if len(rc.Details.Subnets)%2 != 0 {
-		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
-	}
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           rc.Details.Name,
-			Groups:         make([]string, len(rc.Details.Groups)),
-			Ips:            make([]*net.IPNet, len(rc.Details.Ips)/2),
-			Subnets:        make([]*net.IPNet, len(rc.Details.Subnets)/2),
-			NotBefore:      time.Unix(rc.Details.NotBefore, 0),
-			NotAfter:       time.Unix(rc.Details.NotAfter, 0),
-			PublicKey:      make([]byte, len(rc.Details.PublicKey)),
-			IsCA:           rc.Details.IsCA,
-			InvertedGroups: make(map[string]struct{}),
-			Curve:          rc.Details.Curve,
-		},
-		Signature: make([]byte, len(rc.Signature)),
-	}
-
-	copy(nc.Signature, rc.Signature)
-	copy(nc.Details.Groups, rc.Details.Groups)
-	nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer)
-
-	if len(rc.Details.PublicKey) < publicKeyLen {
-		return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
-	}
-	copy(nc.Details.PublicKey, rc.Details.PublicKey)
-
-	for i, rawIp := range rc.Details.Ips {
-		if i%2 == 0 {
-			nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)}
-		} else {
-			nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp))
-		}
-	}
-
-	for i, rawIp := range rc.Details.Subnets {
-		if i%2 == 0 {
-			nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)}
-		} else {
-			nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp))
-		}
-	}
-
-	for _, g := range rc.Details.Groups {
-		nc.Details.InvertedGroups[g] = struct{}{}
-	}
-
-	return &nc, nil
-}
-
-// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data
-// or an error on failure
-func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
-	p, r := pem.Decode(b)
-	if p == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if p.Type != CertBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
-	}
-	nc, err := UnmarshalNebulaCertificate(p.Bytes)
-	return nc, r, err
-}
-
-func MarshalPrivateKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
-
-func MarshalSigningPrivateKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
-
-// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key
-func MarshalX25519PrivateKey(b []byte) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
-}
-
-// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key
-func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
-}
-
-func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var expectedLen int
-	var curve Curve
-	switch k.Type {
-	case X25519PrivateKeyBanner:
-		expectedLen = 32
-		curve = Curve_CURVE25519
-	case P256PrivateKeyBanner:
-		expectedLen = 32
-		curve = Curve_P256
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner")
-	}
-	if len(k.Bytes) != expectedLen {
-		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
-	}
-	return k.Bytes, r, curve, nil
-}
-
-func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var curve Curve
-	switch k.Type {
-	case EncryptedEd25519PrivateKeyBanner:
-		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
-	case EncryptedECDSAP256PrivateKeyBanner:
-		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
-	case Ed25519PrivateKeyBanner:
-		curve = Curve_CURVE25519
-		if len(k.Bytes) != ed25519.PrivateKeySize {
-			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
-		}
-	case ECDSAP256PrivateKeyBanner:
-		curve = Curve_P256
-		if len(k.Bytes) != 32 {
-			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
-		}
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
-	}
-	return k.Bytes, r, curve, nil
-}
-
-// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
-func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
-	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
-	if err != nil {
-		return nil, err
-	}
-
-	b, err = proto.Marshal(&RawNebulaEncryptedData{
-		EncryptionMetadata: &RawNebulaEncryptionMetadata{
-			EncryptionAlgorithm: "AES-256-GCM",
-			Argon2Parameters: &RawNebulaArgon2Parameters{
-				Version:     kdfParams.version,
-				Memory:      kdfParams.Memory,
-				Parallelism: uint32(kdfParams.Parallelism),
-				Iterations:  kdfParams.Iterations,
-				Salt:        kdfParams.salt,
-			},
-		},
-		Ciphertext: ciphertext,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
-	default:
-		return nil, fmt.Errorf("invalid curve: %v", curve)
-	}
-}
-
-// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
-// or an error on failure
-func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != X25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner")
-	}
-	if len(k.Bytes) != publicKeyLen {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b
-// or an error on failure
-func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-
-	if k.Type == EncryptedEd25519PrivateKeyBanner {
-		return nil, r, ErrPrivateKeyEncrypted
-	} else if k.Type != Ed25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner")
-	}
-
-	if len(k.Bytes) != ed25519.PrivateKeySize {
-		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
-// protobuf-generated struct.
-func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
-	if len(b) == 0 {
-		return nil, fmt.Errorf("nil byte array")
-	}
-	var rned RawNebulaEncryptedData
-	err := proto.Unmarshal(b, &rned)
-	if err != nil {
-		return nil, err
-	}
-
-	if rned.EncryptionMetadata == nil {
-		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
-	}
-
-	if rned.EncryptionMetadata.Argon2Parameters == nil {
-		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
-	}
-
-	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
-	if err != nil {
-		return nil, err
-	}
-
-	ned := NebulaEncryptedData{
-		EncryptionMetadata: NebulaEncryptionMetadata{
-			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
-			Argon2Parameters:    *params,
-		},
-		Ciphertext: rned.Ciphertext,
-	}
-
-	return &ned, nil
-}
-
-func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
-	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
-		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
-	}
-	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
-		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
-	}
-	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
-		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
-	}
-	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
-		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
-	}
-
-	return &Argon2Parameters{
-		version:     rune(params.Version),
-		Memory:      uint32(params.Memory),
-		Parallelism: uint8(params.Parallelism),
-		Iterations:  uint32(params.Iterations),
-		salt:        params.Salt,
-	}, nil
-
-}
-
-// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
-// the given passphrase, returning any other bytes b or an error on failure
-func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
-	var curve Curve
-
-	k, r := pem.Decode(b)
-	if k == nil {
-		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-
-	switch k.Type {
-	case EncryptedEd25519PrivateKeyBanner:
-		curve = Curve_CURVE25519
-	case EncryptedECDSAP256PrivateKeyBanner:
-		curve = Curve_P256
-	default:
-		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
-	}
-
-	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
-	if err != nil {
-		return curve, nil, r, err
-	}
-
-	var bytes []byte
-	switch ned.EncryptionMetadata.EncryptionAlgorithm {
-	case "AES-256-GCM":
-		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
-		if err != nil {
-			return curve, nil, r, err
-		}
-	default:
-		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
-	}
-
-	switch curve {
-	case Curve_CURVE25519:
-		if len(bytes) != ed25519.PrivateKeySize {
-			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
-		}
-	case Curve_P256:
-		if len(bytes) != 32 {
-			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
-		}
-	}
-
-	return curve, bytes, r, nil
-}
-
-func MarshalPublicKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
+type Certificate interface {
+	// Version defines the underlying certificate structure and wire protocol version
+	// Version1 certificates are ipv4 only and uses protobuf serialization
+	// Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization
+	Version() Version
 
-// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
-func MarshalX25519PublicKey(b []byte) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
-}
+	// Name is the human-readable name that identifies this certificate.
+	Name() string
 
-// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key
-func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key})
-}
+	// Networks is a list of ip addresses and network sizes assigned to this certificate.
+	// If IsCA is true then certificates signed by this CA can only have ip addresses and
+	// networks that are contained by an entry in this list.
+	Networks() []netip.Prefix
 
-func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var expectedLen int
-	var curve Curve
-	switch k.Type {
-	case X25519PublicKeyBanner:
-		expectedLen = 32
-		curve = Curve_CURVE25519
-	case P256PublicKeyBanner:
-		// Uncompressed
-		expectedLen = 65
-		curve = Curve_P256
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner")
-	}
-	if len(k.Bytes) != expectedLen {
-		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
-	}
-	return k.Bytes, r, curve, nil
-}
+	// UnsafeNetworks is a list of networks that this host can act as an unsafe router for.
+	// If IsCA is true then certificates signed by this CA can only have networks that are
+	// contained by an entry in this list.
+	UnsafeNetworks() []netip.Prefix
 
-// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b
-// or an error on failure
-func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != X25519PublicKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner")
-	}
-	if len(k.Bytes) != publicKeyLen {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key")
-	}
+	// Groups is a list of identities that can be used to write more general firewall rule
+	// definitions.
+	// If IsCA is true then certificates signed by this CA can only use groups that are
+	// in this list.
+	Groups() []string
 
-	return k.Bytes, r, nil
-}
+	// IsCA signifies if this is a certificate authority (true) or a host certificate (false).
+	// It is invalid to use a CA certificate as a host certificate.
+	IsCA() bool
 
-// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b
-// or an error on failure
-func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != Ed25519PublicKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner")
-	}
-	if len(k.Bytes) != ed25519.PublicKeySize {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key")
-	}
+	// NotBefore is the time at which this certificate becomes valid.
+	// If IsCA is true then certificate signed by this CA can not have a time before this.
+	NotBefore() time.Time
 
-	return k.Bytes, r, nil
-}
+	// NotAfter is the time at which this certificate becomes invalid.
+	// If IsCA is true then certificate signed by this CA can not have a time after this.
+	NotAfter() time.Time
 
-// Sign signs a nebula cert with the provided private key
-func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error {
-	if curve != nc.Details.Curve {
-		return fmt.Errorf("curve in cert and private key supplied don't match")
-	}
+	// Issuer is the fingerprint of the CA that signed this certificate.
+	// If IsCA is true then this will be empty.
+	Issuer() string
 
-	b, err := proto.Marshal(nc.getRawDetails())
-	if err != nil {
-		return err
-	}
+	// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
+	PublicKey() []byte
 
-	var sig []byte
+	// Curve identifies which curve was used for the PublicKey and Signature.
+	Curve() Curve
 
-	switch curve {
-	case Curve_CURVE25519:
-		signer := ed25519.PrivateKey(key)
-		sig = ed25519.Sign(signer, b)
-	case Curve_P256:
-		signer := &ecdsa.PrivateKey{
-			PublicKey: ecdsa.PublicKey{
-				Curve: elliptic.P256(),
-			},
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-			D: new(big.Int).SetBytes(key),
-		}
-		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-		signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
+	// Signature is the cryptographic seal for all the details of this certificate.
+	// CheckSignature can be used to verify that the details of this certificate are valid.
+	Signature() []byte
 
-		// We need to hash first for ECDSA
-		// - https://pkg.go.dev/crypto/ecdsa#SignASN1
-		hashed := sha256.Sum256(b)
-		sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
-		if err != nil {
-			return err
-		}
-	default:
-		return fmt.Errorf("invalid curve: %s", nc.Details.Curve)
-	}
+	// CheckSignature will check that the certificate Signature() matches the
+	// computed signature. A true result means this certificate has not been tampered with.
+	CheckSignature(signingPublicKey []byte) bool
 
-	nc.Signature = sig
-	return nil
-}
+	// Fingerprint returns the hex encoded sha256 sum of the certificate.
+	// This acts as a unique fingerprint and can be used to blocklist certificates.
+	Fingerprint() (string, error)
 
-// CheckSignature verifies the signature against the provided public key
-func (nc *NebulaCertificate) CheckSignature(key []byte) bool {
-	b, err := proto.Marshal(nc.getRawDetails())
-	if err != nil {
-		return false
-	}
-	switch nc.Details.Curve {
-	case Curve_CURVE25519:
-		return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature)
-	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
-		hashed := sha256.Sum256(b)
-		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature)
-	default:
-		return false
-	}
-}
+	// Expired tests if the certificate is valid for the provided time.
+	Expired(t time.Time) bool
 
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool {
-	if !useCache {
-		return nc.CheckSignature(key)
-	}
+	// VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key.
+	VerifyPrivateKey(curve Curve, privateKey []byte) error
 
-	if v := nc.signatureVerified.Load(); v != nil {
-		return bytes.Equal(*v, key)
-	}
+	// Marshal will return the byte representation of this certificate
+	// This is primarily the format transmitted on the wire.
+	Marshal() ([]byte, error)
 
-	verified := nc.CheckSignature(key)
-	if verified {
-		keyCopy := make([]byte, len(key))
-		copy(keyCopy, key)
-		nc.signatureVerified.Store(&keyCopy)
-	}
+	// MarshalForHandshakes prepares the bytes needed to use directly in a handshake
+	MarshalForHandshakes() ([]byte, error)
 
-	return verified
-}
+	// MarshalPEM will return a PEM encoded representation of this certificate
+	// This is primarily the format stored on disk
+	MarshalPEM() ([]byte, error)
 
-// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false
-func (nc *NebulaCertificate) Expired(t time.Time) bool {
-	return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
-}
+	// MarshalJSON will return the json representation of this certificate
+	MarshalJSON() ([]byte, error)
 
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
-	return nc.verify(t, ncp, false)
-}
+	// String will return a human-readable representation of this certificate
+	String() string
 
-// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-//
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) {
-	return nc.verify(t, ncp, true)
+	// Copy creates a copy of the certificate
+	Copy() Certificate
 }
 
-// ResetCache resets the cache used by VerifyWithCache.
-func (nc *NebulaCertificate) ResetCache() {
-	nc.sha256sum.Store(nil)
-	nc.signatureVerified.Store(nil)
+// CachedCertificate represents a verified certificate with some cached fields to improve
+// performance.
+type CachedCertificate struct {
+	Certificate       Certificate
+	InvertedGroups    map[string]struct{}
+	Fingerprint       string
+	signerFingerprint string
 }
 
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) {
-	if ncp.isBlocklistedWithCache(nc, useCache) {
-		return false, ErrBlockListed
-	}
-
-	signer, err := ncp.GetCAForCert(nc)
-	if err != nil {
-		return false, err
-	}
-
-	if signer.Expired(t) {
-		return false, ErrRootExpired
-	}
-
-	if nc.Expired(t) {
-		return false, ErrExpired
-	}
-
-	if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) {
-		return false, ErrSignatureMismatch
-	}
-
-	if err := nc.CheckRootConstrains(signer); err != nil {
-		return false, err
-	}
-
-	return true, nil
+func (cc *CachedCertificate) String() string {
+	return cc.Certificate.String()
 }
 
-// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets)
-func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error {
-	// Make sure this cert wasn't valid before the root
-	if signer.Details.NotAfter.Before(nc.Details.NotAfter) {
-		return fmt.Errorf("certificate expires after signing certificate")
-	}
-
-	// Make sure this cert isn't valid after the root
-	if signer.Details.NotBefore.After(nc.Details.NotBefore) {
-		return fmt.Errorf("certificate is valid before the signing certificate")
-	}
-
-	// If the signer has a limited set of groups make sure the cert only contains a subset
-	if len(signer.Details.InvertedGroups) > 0 {
-		for _, g := range nc.Details.Groups {
-			if _, ok := signer.Details.InvertedGroups[g]; !ok {
-				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
-			}
-		}
-	}
-
-	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
-	if len(signer.Details.Ips) > 0 {
-		for _, ip := range nc.Details.Ips {
-			if !netMatch(ip, signer.Details.Ips) {
-				return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String())
-			}
-		}
-	}
-
-	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
-	if len(signer.Details.Subnets) > 0 {
-		for _, subnet := range nc.Details.Subnets {
-			if !netMatch(subnet, signer.Details.Subnets) {
-				return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet)
-			}
-		}
+// Recombine will attempt to unmarshal a certificate received in a handshake.
+// Handshakes save space by placing the peers public key in a different part of the packet, we have to
+// reassemble the actual certificate structure with that in mind.
+func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) {
+	if publicKey == nil {
+		return nil, ErrNoPeerStaticKey
 	}
 
-	return nil
-}
-
-// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
-func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error {
-	if curve != nc.Details.Curve {
-		return fmt.Errorf("curve in cert and private key supplied don't match")
+	if rawCertBytes == nil {
+		return nil, ErrNoPayload
 	}
-	if nc.Details.IsCA {
-		switch curve {
-		case Curve_CURVE25519:
-			// the call to PublicKey below will panic slice bounds out of range otherwise
-			if len(key) != ed25519.PrivateKeySize {
-				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-			}
 
-			if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
-				return fmt.Errorf("public key in cert and private key supplied don't match")
-			}
-		case Curve_P256:
-			privkey, err := ecdh.P256().NewPrivateKey(key)
-			if err != nil {
-				return fmt.Errorf("cannot parse private key as P256")
-			}
-			pub := privkey.PublicKey().Bytes()
-			if !bytes.Equal(pub, nc.Details.PublicKey) {
-				return fmt.Errorf("public key in cert and private key supplied don't match")
-			}
-		default:
-			return fmt.Errorf("invalid curve: %s", curve)
-		}
-		return nil
-	}
+	var c Certificate
+	var err error
 
-	var pub []byte
-	switch curve {
-	case Curve_CURVE25519:
-		var err error
-		pub, err = curve25519.X25519(key, curve25519.Basepoint)
-		if err != nil {
-			return err
-		}
-	case Curve_P256:
-		privkey, err := ecdh.P256().NewPrivateKey(key)
-		if err != nil {
-			return err
-		}
-		pub = privkey.PublicKey().Bytes()
+	switch v {
+	// Implementations must ensure the result is a valid cert!
+	case VersionPre1, Version1:
+		c, err = unmarshalCertificateV1(rawCertBytes, publicKey)
+	case Version2:
+		c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve)
 	default:
-		return fmt.Errorf("invalid curve: %s", curve)
-	}
-	if !bytes.Equal(pub, nc.Details.PublicKey) {
-		return fmt.Errorf("public key in cert and private key supplied don't match")
-	}
-
-	return nil
-}
-
-// String will return a pretty printed representation of a nebula cert
-func (nc *NebulaCertificate) String() string {
-	if nc == nil {
-		return "NebulaCertificate {}\n"
+		//TODO: CERT-V2 make a static var
+		return nil, fmt.Errorf("unknown certificate version %d", v)
 	}
 
-	s := "NebulaCertificate {\n"
-	s += "\tDetails {\n"
-	s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name)
-
-	if len(nc.Details.Ips) > 0 {
-		s += "\t\tIps: [\n"
-		for _, ip := range nc.Details.Ips {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tIps: []\n"
-	}
-
-	if len(nc.Details.Subnets) > 0 {
-		s += "\t\tSubnets: [\n"
-		for _, ip := range nc.Details.Subnets {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tSubnets: []\n"
-	}
-
-	if len(nc.Details.Groups) > 0 {
-		s += "\t\tGroups: [\n"
-		for _, g := range nc.Details.Groups {
-			s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tGroups: []\n"
-	}
-
-	s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore)
-	s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter)
-	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA)
-	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer)
-	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey)
-	s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve)
-	s += "\t}\n"
-	fp, err := nc.Sha256Sum()
-	if err == nil {
-		s += fmt.Sprintf("\tFingerprint: %s\n", fp)
-	}
-	s += fmt.Sprintf("\tSignature: %x\n", nc.Signature)
-	s += "}"
-
-	return s
-}
-
-// getRawDetails marshals the raw details into protobuf ready struct
-func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails {
-	rd := &RawNebulaCertificateDetails{
-		Name:      nc.Details.Name,
-		Groups:    nc.Details.Groups,
-		NotBefore: nc.Details.NotBefore.Unix(),
-		NotAfter:  nc.Details.NotAfter.Unix(),
-		PublicKey: make([]byte, len(nc.Details.PublicKey)),
-		IsCA:      nc.Details.IsCA,
-		Curve:     nc.Details.Curve,
-	}
-
-	for _, ipNet := range nc.Details.Ips {
-		rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask))
-	}
-
-	for _, ipNet := range nc.Details.Subnets {
-		rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask))
-	}
-
-	copy(rd.PublicKey, nc.Details.PublicKey[:])
-
-	// I know, this is terrible
-	rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer)
-
-	return rd
-}
-
-// Marshal will marshal a nebula cert into a protobuf byte array
-func (nc *NebulaCertificate) Marshal() ([]byte, error) {
-	rc := RawNebulaCertificate{
-		Details:   nc.getRawDetails(),
-		Signature: nc.Signature,
-	}
-
-	return proto.Marshal(&rc)
-}
-
-// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result
-func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) {
-	b, err := nc.Marshal()
 	if err != nil {
 		return nil, err
 	}
-	return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil
-}
-
-// Sha256Sum calculates a sha-256 sum of the marshaled certificate
-func (nc *NebulaCertificate) Sha256Sum() (string, error) {
-	b, err := nc.Marshal()
-	if err != nil {
-		return "", err
-	}
-
-	sum := sha256.Sum256(b)
-	return hex.EncodeToString(sum[:]), nil
-}
-
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) {
-	if !useCache {
-		return nc.Sha256Sum()
-	}
-
-	if s := nc.sha256sum.Load(); s != nil {
-		return *s, nil
-	}
-	s, err := nc.Sha256Sum()
-	if err != nil {
-		return s, err
-	}
-
-	nc.sha256sum.Store(&s)
-	return s, nil
-}
-
-func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
-	toString := func(ips []*net.IPNet) []string {
-		s := []string{}
-		for _, ip := range ips {
-			s = append(s, ip.String())
-		}
-		return s
-	}
-
-	fp, _ := nc.Sha256Sum()
-	jc := m{
-		"details": m{
-			"name":      nc.Details.Name,
-			"ips":       toString(nc.Details.Ips),
-			"subnets":   toString(nc.Details.Subnets),
-			"groups":    nc.Details.Groups,
-			"notBefore": nc.Details.NotBefore,
-			"notAfter":  nc.Details.NotAfter,
-			"publicKey": fmt.Sprintf("%x", nc.Details.PublicKey),
-			"isCa":      nc.Details.IsCA,
-			"issuer":    nc.Details.Issuer,
-			"curve":     nc.Details.Curve.String(),
-		},
-		"fingerprint": fp,
-		"signature":   fmt.Sprintf("%x", nc.Signature),
-	}
-	return json.Marshal(jc)
-}
-
-//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
-//	r, err := nc.Marshal()
-//	if err != nil {
-//		//TODO
-//		return nil
-//	}
-//
-//	c, err := UnmarshalNebulaCertificate(r)
-//	return c
-//}
-
-func (nc *NebulaCertificate) Copy() *NebulaCertificate {
-	c := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           nc.Details.Name,
-			Groups:         make([]string, len(nc.Details.Groups)),
-			Ips:            make([]*net.IPNet, len(nc.Details.Ips)),
-			Subnets:        make([]*net.IPNet, len(nc.Details.Subnets)),
-			NotBefore:      nc.Details.NotBefore,
-			NotAfter:       nc.Details.NotAfter,
-			PublicKey:      make([]byte, len(nc.Details.PublicKey)),
-			IsCA:           nc.Details.IsCA,
-			Issuer:         nc.Details.Issuer,
-			InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
-		},
-		Signature: make([]byte, len(nc.Signature)),
-	}
-
-	copy(c.Signature, nc.Signature)
-	copy(c.Details.Groups, nc.Details.Groups)
-	copy(c.Details.PublicKey, nc.Details.PublicKey)
-
-	for i, p := range nc.Details.Ips {
-		c.Details.Ips[i] = &net.IPNet{
-			IP:   make(net.IP, len(p.IP)),
-			Mask: make(net.IPMask, len(p.Mask)),
-		}
-		copy(c.Details.Ips[i].IP, p.IP)
-		copy(c.Details.Ips[i].Mask, p.Mask)
-	}
-
-	for i, p := range nc.Details.Subnets {
-		c.Details.Subnets[i] = &net.IPNet{
-			IP:   make(net.IP, len(p.IP)),
-			Mask: make(net.IPMask, len(p.Mask)),
-		}
-		copy(c.Details.Subnets[i].IP, p.IP)
-		copy(c.Details.Subnets[i].Mask, p.Mask)
-	}
-
-	for g := range nc.Details.InvertedGroups {
-		c.Details.InvertedGroups[g] = struct{}{}
-	}
-
-	return c
-}
-
-func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
-	for _, net := range rootIps {
-		if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
-			return true
-		}
-	}
-
-	return false
-}
-
-func maskContains(caMask, certMask net.IPMask) bool {
-	caM := maskTo4(caMask)
-	cM := maskTo4(certMask)
-	// Make sure forcing to ipv4 didn't nuke us
-	if caM == nil || cM == nil {
-		return false
-	}
-
-	// Make sure the cert mask is not greater than the ca mask
-	for i := 0; i < len(caMask); i++ {
-		if caM[i] > cM[i] {
-			return false
-		}
-	}
-
-	return true
-}
-
-func maskTo4(ip net.IPMask) net.IPMask {
-	if len(ip) == net.IPv4len {
-		return ip
-	}
 
-	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
-		return ip[12:16]
+	if c.Curve() != curve {
+		return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
 	}
 
-	return nil
-}
-
-func isZeros(b []byte) bool {
-	for i := 0; i < len(b); i++ {
-		if b[i] != 0 {
-			return false
-		}
-	}
-	return true
-}
-
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, net.IPv4len)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
+	return c, nil
 }

+ 0 - 1230
cert/cert_test.go

@@ -1,1230 +0,0 @@
-package cert
-
-import (
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/elliptic"
-	"crypto/rand"
-	"fmt"
-	"io"
-	"net"
-	"testing"
-	"time"
-
-	"github.com/slackhq/nebula/test"
-	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-	"google.golang.org/protobuf/proto"
-)
-
-func TestMarshalingNebulaCertificate(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-
-	nc2, err := UnmarshalNebulaCertificate(b)
-	assert.Nil(t, err)
-
-	assert.Equal(t, nc.Signature, nc2.Signature)
-	assert.Equal(t, nc.Details.Name, nc2.Details.Name)
-	assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore)
-	assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter)
-	assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey)
-	assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA)
-
-	// IP byte arrays can be 4 or 16 in length so we have to go this route
-	assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips))
-	for i, wIp := range nc.Details.Ips {
-		assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String())
-	}
-
-	assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets))
-	for i, wIp := range nc.Details.Subnets {
-		assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String())
-	}
-
-	assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups)
-}
-
-func TestNebulaCertificate_Sign(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-	}
-
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	assert.Nil(t, err)
-	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(Curve_CURVE25519, priv))
-	assert.True(t, nc.CheckSignature(pub))
-
-	_, err = nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-}
-
-func TestNebulaCertificate_SignP256(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Curve:     Curve_P256,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-	}
-
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	assert.Nil(t, err)
-	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(Curve_P256, rawPriv))
-	assert.True(t, nc.CheckSignature(pub))
-
-	_, err = nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-}
-
-func TestNebulaCertificate_Expired(t *testing.T) {
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
-			NotAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
-		},
-	}
-
-	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
-	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
-	assert.False(t, nc.Expired(time.Now()))
-}
-
-func TestNebulaCertificate_MarshalJSON(t *testing.T) {
-	time.Local = time.UTC
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
-			NotAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.MarshalJSON()
-	assert.Nil(t, err)
-	assert.Equal(
-		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
-		string(b),
-	)
-}
-
-func TestNebulaCertificate_Verify(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	h, err := ca.Sha256Sum()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.CAs[h] = ca
-
-	f, err := c.Sha256Sum()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is expired")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	h, err := ca.Sha256Sum()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.CAs[h] = ca
-
-	f, err := c.Sha256Sum()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is expired")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_IPs(t *testing.T) {
-	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
-	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	// ip is outside the network
-	cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}}
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_Subnets(t *testing.T) {
-	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
-	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	// ip is outside the network
-	cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}}
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := x25519Keypair()
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(Curve_P256, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := p256Keypair()
-	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNewCAPoolFromBytes(t *testing.T) {
-	noNewLines := `
-# Current provisional, Remove once everything moves over to the real root.
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-# root-ca01
------BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
------END NEBULA CERTIFICATE-----
-`
-
-	withNewLines := `
-# Current provisional, Remove once everything moves over to the real root.
-
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-
-# root-ca01
-
-
------BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
------END NEBULA CERTIFICATE-----
-
-`
-
-	expired := `
-# expired certificate
------BEGIN NEBULA CERTIFICATE-----
-CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4
-vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie
-WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
------END NEBULA CERTIFICATE-----
-`
-
-	p256 := `
-# p256 certificate
------BEGIN NEBULA CERTIFICATE-----
-CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
-6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H
-76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
-IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
------END NEBULA CERTIFICATE-----
-`
-
-	rootCA := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula root ca",
-		},
-	}
-
-	rootCA01 := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula root ca 01",
-		},
-	}
-
-	rootCAP256 := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula P256 test",
-		},
-	}
-
-	p, err := NewCAPoolFromBytes([]byte(noNewLines))
-	assert.Nil(t, err)
-	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-
-	pp, err := NewCAPoolFromBytes([]byte(withNewLines))
-	assert.Nil(t, err)
-	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-
-	// expired cert, no valid certs
-	ppp, err := NewCAPoolFromBytes([]byte(expired))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
-
-	// expired cert, with valid certs
-	pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
-	assert.Equal(t, len(pppp.CAs), 3)
-
-	ppppp, err := NewCAPoolFromBytes([]byte(p256))
-	assert.Nil(t, err)
-	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
-	assert.Equal(t, len(ppppp.CAs), 1)
-}
-
-func appendByteSlices(b ...[]byte) []byte {
-	retSlice := []byte{}
-	for _, v := range b {
-		retSlice = append(retSlice, v...)
-	}
-	return retSlice
-}
-
-func TestUnmrshalCertPEM(t *testing.T) {
-	goodCert := []byte(`
-# A good cert
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-`)
-	badBanner := []byte(`# A bad banner
------BEGIN NOT A NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NOT A NEBULA CERTIFICATE-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
--END NEBULA CERTIFICATE----`)
-
-	certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
-
-	// Success test case
-	cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle)
-	assert.NotNil(t, cert)
-	assert.Equal(t, rest, append(badBanner, invalidPem...))
-	assert.Nil(t, err)
-
-	// Fail due to invalid banner.
-	cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
-	assert.Nil(t, cert)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
-	assert.Nil(t, cert)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalSigningPrivateKey(t *testing.T) {
-	privKey := []byte(`# A good key
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	privP256Key := []byte(`# A good key
------BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ECDSA P256 PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NOT A NEBULA PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
--END NEBULA ED25519 PRIVATE KEY-----`)
-
-	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle)
-	assert.Len(t, k, 64)
-	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-	assert.Nil(t, err)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
-	passphrase := []byte("DO NOT USE THIS KEY")
-	privKey := []byte(`# A good key
------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
------END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A key which, once decrypted, is too short
------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
-k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
-GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
-rQr3bdH3Oy/WiYU=
------END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner (not encrypted)
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
-XgLvodMXZJuaFPssp+WwtA==
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
--END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-
-	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
-	assert.Nil(t, err)
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Len(t, k, 64)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-
-	// Fail due to short key
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-
-	// Fail due to invalid banner
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to invalid passphrase
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
-	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, []byte{})
-}
-
-func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
-	// Having proved that decryption works correctly above, we can test the
-	// encryption function produces a value which can be decrypted
-	passphrase := []byte("passphrase")
-	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
-	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
-	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
-	assert.Nil(t, err)
-
-	// Verify the "key" can be decrypted successfully
-	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
-	assert.Len(t, k, 64)
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Equal(t, rest, []byte{})
-	assert.Nil(t, err)
-
-	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
-}
-
-func TestUnmarshalPrivateKey(t *testing.T) {
-	privKey := []byte(`# A good key
------BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA X25519 PRIVATE KEY-----
-`)
-	privP256Key := []byte(`# A good key
------BEGIN NEBULA P256 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA X25519 PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA X25519 PRIVATE KEY-----`)
-
-	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalPrivateKey(keyBundle)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-	assert.Nil(t, err)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalEd25519PublicKey(t *testing.T) {
-	pubKey := []byte(`# A good key
------BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ED25519 PUBLIC KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA ED25519 PUBLIC KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PUBLIC KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA ED25519 PUBLIC KEY-----`)
-
-	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, err := UnmarshalEd25519PublicKey(keyBundle)
-	assert.Equal(t, len(k), 32)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-
-	// Fail due to short key
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key")
-
-	// Fail due to invalid banner
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner")
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalX25519PublicKey(t *testing.T) {
-	pubKey := []byte(`# A good key
------BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA X25519 PUBLIC KEY-----
-`)
-	pubP256Key := []byte(`# A good key
------BEGIN NEBULA P256 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-AAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PUBLIC KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA X25519 PUBLIC KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PUBLIC KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA X25519 PUBLIC KEY-----`)
-
-	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalPublicKey(keyBundle)
-	assert.Equal(t, len(k), 32)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Equal(t, len(k), 65)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner")
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-// Ensure that upgrading the protobuf library does not change how certificates
-// are marshalled, since this would break signature verification
-func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
-	before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
-	after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-	assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
-
-	b, err = proto.Marshal(nc.getRawDetails())
-	assert.Nil(t, err)
-	//t.Log("Raw cert size:", len(b))
-	assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
-}
-
-func TestNebulaCertificate_Copy(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	cc := c.Copy()
-
-	test.AssertDeepCopyEqual(t, c, cc)
-}
-
-func TestUnmarshalNebulaCertificate(t *testing.T) {
-	// Test that we don't panic with an invalid certificate (#332)
-	data := []byte("\x98\x00\x00")
-	_, err := UnmarshalNebulaCertificate(data)
-	assert.EqualError(t, err, "encoded Details was nil")
-}
-
-func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(Curve_CURVE25519, priv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, priv, nil
-}
-
-func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			Curve:          Curve_P256,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(Curve_P256, rawPriv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, rawPriv, nil
-}
-
-func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	if len(groups) == 0 {
-		groups = []string{"test-group1", "test-group2", "test-group3"}
-	}
-
-	if len(ips) == 0 {
-		ips = []*net.IPNet{
-			{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
-			{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
-			{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
-		}
-	}
-
-	if len(subnets) == 0 {
-		subnets = []*net.IPNet{
-			{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
-			{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
-			{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
-		}
-	}
-
-	var pub, rawPriv []byte
-
-	switch ca.Details.Curve {
-	case Curve_CURVE25519:
-		pub, rawPriv = x25519Keypair()
-	case Curve_P256:
-		pub, rawPriv = p256Keypair()
-	default:
-		return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "testing",
-			Ips:            ips,
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Curve:          ca.Details.Curve,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	return nc, pub, rawPriv, nil
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
-func p256Keypair() ([]byte, []byte) {
-	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
-	if err != nil {
-		panic(err)
-	}
-	pubkey := privkey.PublicKey()
-	return pubkey.Bytes(), privkey.Bytes()
-}

+ 489 - 0
cert/cert_v1.go

@@ -0,0 +1,489 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"google.golang.org/protobuf/proto"
+)
+
+const publicKeyLen = 32
+
+type certificateV1 struct {
+	details   detailsV1
+	signature []byte
+}
+
+type detailsV1 struct {
+	name           string
+	networks       []netip.Prefix
+	unsafeNetworks []netip.Prefix
+	groups         []string
+	notBefore      time.Time
+	notAfter       time.Time
+	publicKey      []byte
+	isCA           bool
+	issuer         string
+
+	curve Curve
+}
+
+type m = map[string]any
+
+func (c *certificateV1) Version() Version {
+	return Version1
+}
+
+func (c *certificateV1) Curve() Curve {
+	return c.details.curve
+}
+
+func (c *certificateV1) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV1) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV1) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV1) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV1) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV1) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV1) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV1) PublicKey() []byte {
+	return c.details.publicKey
+}
+
+func (c *certificateV1) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV1) Fingerprint() (string, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return "", err
+	}
+
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV1) CheckSignature(key []byte) bool {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return false
+	}
+	switch c.details.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV1) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.details.curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+			}
+
+			if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return fmt.Errorf("cannot parse private key as P256: %w", err)
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.details.publicKey) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return err
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return err
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.details.publicKey) {
+		return fmt.Errorf("public key in cert and private key supplied don't match")
+	}
+
+	return nil
+}
+
+// getRawDetails marshals the raw details into protobuf ready struct
+func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
+	rd := &RawNebulaCertificateDetails{
+		Name:      c.details.name,
+		Groups:    c.details.groups,
+		NotBefore: c.details.notBefore.Unix(),
+		NotAfter:  c.details.notAfter.Unix(),
+		PublicKey: make([]byte, len(c.details.publicKey)),
+		IsCA:      c.details.isCA,
+		Curve:     c.details.curve,
+	}
+
+	for _, ipNet := range c.details.networks {
+		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
+		rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
+	}
+
+	for _, ipNet := range c.details.unsafeNetworks {
+		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
+		rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
+	}
+
+	copy(rd.PublicKey, c.details.publicKey[:])
+
+	// I know, this is terrible
+	rd.Issuer, _ = hex.DecodeString(c.details.issuer)
+
+	return rd
+}
+
+func (c *certificateV1) String() string {
+	b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
+	pubKey := c.details.publicKey
+	c.details.publicKey = nil
+	rawCertNoKey, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	c.details.publicKey = pubKey
+	return rawCertNoKey, nil
+}
+
+func (c *certificateV1) Marshal() ([]byte, error) {
+	rc := RawNebulaCertificate{
+		Details:   c.getRawDetails(),
+		Signature: c.signature,
+	}
+
+	return proto.Marshal(&rc)
+}
+
+func (c *certificateV1) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
+}
+
+func (c *certificateV1) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.marshalJSON())
+}
+
+func (c *certificateV1) marshalJSON() m {
+	fp, _ := c.Fingerprint()
+	return m{
+		"version": Version1,
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"publicKey":      fmt.Sprintf("%x", c.details.publicKey),
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+			"curve":          c.details.curve.String(),
+		},
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}
+}
+
+func (c *certificateV1) Copy() Certificate {
+	nc := &certificateV1{
+		details: detailsV1{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			publicKey: make([]byte, len(c.details.publicKey)),
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+			curve:     c.details.curve,
+		},
+		signature: make([]byte, len(c.signature)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.signature, c.signature)
+	copy(nc.details.publicKey, c.details.publicKey)
+
+	return nc
+}
+
+func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV1{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		publicKey:      t.PublicKey,
+		isCA:           t.IsCA,
+		curve:          t.Curve,
+		issuer:         t.issuer,
+	}
+
+	return c.validate()
+}
+
+func (c *certificateV1) validate() error {
+	// Empty names are allowed
+
+	if len(c.details.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
+	// Continue to allow this behavior
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
+	}
+
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+	}
+
+	// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
+	// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
+	// unsafe networks would result in a different signature.
+
+	return nil
+}
+
+func (c *certificateV1) marshalForSigning() ([]byte, error) {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return nil, err
+	}
+	return b, nil
+}
+
+func (c *certificateV1) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
+// if the publicKey is provided here then it is not required to be present in `b`
+func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rc RawNebulaCertificate
+	err := proto.Unmarshal(b, &rc)
+	if err != nil {
+		return nil, err
+	}
+
+	if rc.Details == nil {
+		return nil, fmt.Errorf("encoded Details was nil")
+	}
+
+	if len(rc.Details.Ips)%2 != 0 {
+		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
+	}
+
+	if len(rc.Details.Subnets)%2 != 0 {
+		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
+	}
+
+	nc := certificateV1{
+		details: detailsV1{
+			name:           rc.Details.Name,
+			groups:         make([]string, len(rc.Details.Groups)),
+			networks:       make([]netip.Prefix, len(rc.Details.Ips)/2),
+			unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
+			notBefore:      time.Unix(rc.Details.NotBefore, 0),
+			notAfter:       time.Unix(rc.Details.NotAfter, 0),
+			publicKey:      make([]byte, len(rc.Details.PublicKey)),
+			isCA:           rc.Details.IsCA,
+			curve:          rc.Details.Curve,
+		},
+		signature: make([]byte, len(rc.Signature)),
+	}
+
+	copy(nc.signature, rc.Signature)
+	copy(nc.details.groups, rc.Details.Groups)
+	nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
+
+	if len(publicKey) > 0 {
+		nc.details.publicKey = publicKey
+	}
+
+	copy(nc.details.publicKey, rc.Details.PublicKey)
+
+	var ip netip.Addr
+	for i, rawIp := range rc.Details.Ips {
+		if i%2 == 0 {
+			ip = int2addr(rawIp)
+		} else {
+			ones, _ := net.IPMask(int2ip(rawIp)).Size()
+			nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
+		}
+	}
+
+	for i, rawIp := range rc.Details.Subnets {
+		if i%2 == 0 {
+			ip = int2addr(rawIp)
+		} else {
+			ones, _ := net.IPMask(int2ip(rawIp)).Size()
+			nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
+		}
+	}
+
+	err = nc.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return &nc, nil
+}
+
+func ip2int(ip []byte) uint32 {
+	if len(ip) == 16 {
+		return binary.BigEndian.Uint32(ip[12:16])
+	}
+	return binary.BigEndian.Uint32(ip)
+}
+
+func int2ip(nn uint32) net.IP {
+	ip := make(net.IP, net.IPv4len)
+	binary.BigEndian.PutUint32(ip, nn)
+	return ip
+}
+
+func addr2int(addr netip.Addr) uint32 {
+	b := addr.Unmap().As4()
+	return binary.BigEndian.Uint32(b[:])
+}
+
+func int2addr(nn uint32) netip.Addr {
+	ip := [4]byte{}
+	binary.BigEndian.PutUint32(ip[:], nn)
+	return netip.AddrFrom4(ip).Unmap()
+}

+ 111 - 111
cert/cert.pb.go → cert/cert_v1.pb.go

@@ -1,8 +1,8 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
-// 	protoc-gen-go v1.30.0
+// 	protoc-gen-go v1.34.2
 // 	protoc        v3.21.5
-// source: cert.proto
+// source: cert_v1.proto
 
 package cert
 
@@ -50,11 +50,11 @@ func (x Curve) String() string {
 }
 
 func (Curve) Descriptor() protoreflect.EnumDescriptor {
-	return file_cert_proto_enumTypes[0].Descriptor()
+	return file_cert_v1_proto_enumTypes[0].Descriptor()
 }
 
 func (Curve) Type() protoreflect.EnumType {
-	return &file_cert_proto_enumTypes[0]
+	return &file_cert_v1_proto_enumTypes[0]
 }
 
 func (x Curve) Number() protoreflect.EnumNumber {
@@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber {
 
 // Deprecated: Use Curve.Descriptor instead.
 func (Curve) EnumDescriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{0}
+	return file_cert_v1_proto_rawDescGZIP(), []int{0}
 }
 
 type RawNebulaCertificate struct {
@@ -78,7 +78,7 @@ type RawNebulaCertificate struct {
 func (x *RawNebulaCertificate) Reset() {
 	*x = RawNebulaCertificate{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[0]
+		mi := &file_cert_v1_proto_msgTypes[0]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string {
 func (*RawNebulaCertificate) ProtoMessage() {}
 
 func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[0]
+	mi := &file_cert_v1_proto_msgTypes[0]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
 func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{0}
+	return file_cert_v1_proto_rawDescGZIP(), []int{0}
 }
 
 func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
@@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct {
 func (x *RawNebulaCertificateDetails) Reset() {
 	*x = RawNebulaCertificateDetails{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[1]
+		mi := &file_cert_v1_proto_msgTypes[1]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string {
 func (*RawNebulaCertificateDetails) ProtoMessage() {}
 
 func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[1]
+	mi := &file_cert_v1_proto_msgTypes[1]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
 func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{1}
+	return file_cert_v1_proto_rawDescGZIP(), []int{1}
 }
 
 func (x *RawNebulaCertificateDetails) GetName() string {
@@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct {
 func (x *RawNebulaEncryptedData) Reset() {
 	*x = RawNebulaEncryptedData{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[2]
+		mi := &file_cert_v1_proto_msgTypes[2]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string {
 func (*RawNebulaEncryptedData) ProtoMessage() {}
 
 func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[2]
+	mi := &file_cert_v1_proto_msgTypes[2]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
 func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{2}
+	return file_cert_v1_proto_rawDescGZIP(), []int{2}
 }
 
 func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
@@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct {
 func (x *RawNebulaEncryptionMetadata) Reset() {
 	*x = RawNebulaEncryptionMetadata{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[3]
+		mi := &file_cert_v1_proto_msgTypes[3]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string {
 func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
 
 func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[3]
+	mi := &file_cert_v1_proto_msgTypes[3]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
 func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{3}
+	return file_cert_v1_proto_rawDescGZIP(), []int{3}
 }
 
 func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
@@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct {
 func (x *RawNebulaArgon2Parameters) Reset() {
 	*x = RawNebulaArgon2Parameters{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[4]
+		mi := &file_cert_v1_proto_msgTypes[4]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string {
 func (*RawNebulaArgon2Parameters) ProtoMessage() {}
 
 func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[4]
+	mi := &file_cert_v1_proto_msgTypes[4]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
 func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{4}
+	return file_cert_v1_proto_rawDescGZIP(), []int{4}
 }
 
 func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
@@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
 	return nil
 }
 
-var File_cert_proto protoreflect.FileDescriptor
-
-var file_cert_proto_rawDesc = []byte{
-	0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65,
-	0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
-	0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65,
-	0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74,
-	0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
-	0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
-	0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
-	0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
-	0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
-	0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
-	0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
-	0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53,
-	0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75,
-	0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
-	0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a,
-	0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03,
-	0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e,
-	0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e,
-	0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69,
-	0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c,
-	0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
-	0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
-	0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
-	0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e,
-	0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63,
-	0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
-	0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12,
-	0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
-	0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72,
-	0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12,
-	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74,
-	0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
-	0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61,
-	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
-	0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
-	0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72,
-	0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61,
-	0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f,
-	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
-	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52,
-	0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72,
-	0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
-	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12,
-	0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05,
-	0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d,
-	0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72,
-	0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
-	0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
-	0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
-	0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69,
-	0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
-	0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65,
-	0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00,
-	0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69,
-	0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71,
-	0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72,
-	0x6f, 0x74, 0x6f, 0x33,
+var File_cert_v1_proto protoreflect.FileDescriptor
+
+var file_cert_v1_proto_rawDesc = []byte{
+	0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
+	0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a,
+	0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
+	0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c,
+	0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69,
+	0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53,
+	0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77,
+	0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74,
+	0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
+	0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03,
+	0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18,
+	0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52,
+	0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75,
+	0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73,
+	0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20,
+	0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a,
+	0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03,
+	0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75,
+	0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50,
+	0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41,
+	0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06,
+	0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73,
+	0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20,
+	0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65,
+	0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e,
+	0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61,
+	0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
+	0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
+	0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
+	0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74,
+	0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65,
+	0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
+	0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
+	0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01,
+	0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c,
+	0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e,
+	0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28,
+	0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
+	0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65,
+	0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
+	0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
+	0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06,
+	0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65,
+	0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
+	0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c,
+	0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74,
+	0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72,
+	0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05,
+	0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75,
+	0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31,
+	0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a,
+	0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63,
+	0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62,
+	0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (
-	file_cert_proto_rawDescOnce sync.Once
-	file_cert_proto_rawDescData = file_cert_proto_rawDesc
+	file_cert_v1_proto_rawDescOnce sync.Once
+	file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc
 )
 
-func file_cert_proto_rawDescGZIP() []byte {
-	file_cert_proto_rawDescOnce.Do(func() {
-		file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData)
+func file_cert_v1_proto_rawDescGZIP() []byte {
+	file_cert_v1_proto_rawDescOnce.Do(func() {
+		file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData)
 	})
-	return file_cert_proto_rawDescData
+	return file_cert_v1_proto_rawDescData
 }
 
-var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
-var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
-var file_cert_proto_goTypes = []interface{}{
+var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
+var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
+var file_cert_v1_proto_goTypes = []any{
 	(Curve)(0),                          // 0: cert.Curve
 	(*RawNebulaCertificate)(nil),        // 1: cert.RawNebulaCertificate
 	(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
@@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{
 	(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
 	(*RawNebulaArgon2Parameters)(nil),   // 5: cert.RawNebulaArgon2Parameters
 }
-var file_cert_proto_depIdxs = []int32{
+var file_cert_v1_proto_depIdxs = []int32{
 	2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
 	0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
 	4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
@@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{
 	0, // [0:4] is the sub-list for field type_name
 }
 
-func init() { file_cert_proto_init() }
-func file_cert_proto_init() {
-	if File_cert_proto != nil {
+func init() { file_cert_v1_proto_init() }
+func file_cert_v1_proto_init() {
+	if File_cert_v1_proto != nil {
 		return
 	}
 	if !protoimpl.UnsafeEnabled {
-		file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaCertificate); i {
 			case 0:
 				return &v.state
@@ -549,7 +549,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaCertificateDetails); i {
 			case 0:
 				return &v.state
@@ -561,7 +561,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaEncryptedData); i {
 			case 0:
 				return &v.state
@@ -573,7 +573,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaEncryptionMetadata); i {
 			case 0:
 				return &v.state
@@ -585,7 +585,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaArgon2Parameters); i {
 			case 0:
 				return &v.state
@@ -602,19 +602,19 @@ func file_cert_proto_init() {
 	out := protoimpl.TypeBuilder{
 		File: protoimpl.DescBuilder{
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
-			RawDescriptor: file_cert_proto_rawDesc,
+			RawDescriptor: file_cert_v1_proto_rawDesc,
 			NumEnums:      1,
 			NumMessages:   5,
 			NumExtensions: 0,
 			NumServices:   0,
 		},
-		GoTypes:           file_cert_proto_goTypes,
-		DependencyIndexes: file_cert_proto_depIdxs,
-		EnumInfos:         file_cert_proto_enumTypes,
-		MessageInfos:      file_cert_proto_msgTypes,
+		GoTypes:           file_cert_v1_proto_goTypes,
+		DependencyIndexes: file_cert_v1_proto_depIdxs,
+		EnumInfos:         file_cert_v1_proto_enumTypes,
+		MessageInfos:      file_cert_v1_proto_msgTypes,
 	}.Build()
-	File_cert_proto = out.File
-	file_cert_proto_rawDesc = nil
-	file_cert_proto_goTypes = nil
-	file_cert_proto_depIdxs = nil
+	File_cert_v1_proto = out.File
+	file_cert_v1_proto_rawDesc = nil
+	file_cert_v1_proto_goTypes = nil
+	file_cert_v1_proto_depIdxs = nil
 }

+ 0 - 0
cert/cert.proto → cert/cert_v1.proto


+ 218 - 0
cert/cert_v1_test.go

@@ -0,0 +1,218 @@
+package cert
+
+import (
+	"fmt"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/protobuf/proto"
+)
+
+func TestCertificateV1_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	require.NoError(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV1(b, nil)
+	require.NoError(t, err)
+
+	assert.Equal(t, Version1, nc.Version())
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV1_Expired(t *testing.T) {
+	nc := certificateV1{
+		details: detailsV1{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV1_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.MarshalJSON()
+	require.NoError(t, err)
+	assert.JSONEq(
+		t,
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
+		string(b),
+	)
+}
+
+func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	require.NoError(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	require.NoError(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	require.Error(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	require.NoError(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	require.Error(t, err)
+}
+
+func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	require.NoError(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	require.NoError(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	require.Error(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	require.NoError(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	require.Error(t, err)
+}
+
+// Ensure that upgrading the protobuf library does not change how certificates
+// are marshalled, since this would break signature verification
+func TestMarshalingCertificateV1Consistency(t *testing.T) {
+	before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
+	after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	require.NoError(t, err)
+	assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
+
+	b, err = proto.Marshal(nc.getRawDetails())
+	require.NoError(t, err)
+	assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
+}
+
+func TestCertificateV1_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV1(t *testing.T) {
+	// Test that we don't panic with an invalid certificate (#332)
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV1(data, nil)
+	require.EqualError(t, err, "encoded Details was nil")
+}
+
+func appendByteSlices(b ...[]byte) []byte {
+	retSlice := []byte{}
+	for _, v := range b {
+		retSlice = append(retSlice, v...)
+	}
+	return retSlice
+}
+
+func mustParsePrefixUnmapped(s string) netip.Prefix {
+	prefix := netip.MustParsePrefix(s)
+	return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
+}

+ 37 - 0
cert/cert_v2.asn1

@@ -0,0 +1,37 @@
+Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
+
+Name ::= UTF8String (SIZE (1..253))
+Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
+Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
+Curve ::= ENUMERATED {
+    curve25519 (0),
+    p256 (1)
+}
+
+-- The maximum size of a certificate must not exceed 65536 bytes
+Certificate ::= SEQUENCE {
+    details OCTET STRING,
+    curve Curve DEFAULT curve25519,
+    publicKey OCTET STRING,
+    -- signature(details + curve + publicKey) using the appropriate method for curve
+    signature OCTET STRING
+}
+
+Details ::= SEQUENCE {
+    name Name,
+
+    -- At least 1 ipv4 or ipv6 address must be present if isCA is false
+    networks SEQUENCE OF Network OPTIONAL,
+    unsafeNetworks SEQUENCE OF Network OPTIONAL,
+    groups SEQUENCE OF Name OPTIONAL,
+    isCA BOOLEAN DEFAULT false,
+    notBefore Time,
+    notAfter Time,
+
+    -- issuer is only required if isCA is false, if isCA is true then it must not be present
+    issuer OCTET STRING OPTIONAL,
+    ...
+    -- New fields can be added below here
+}
+
+END

+ 730 - 0
cert/cert_v2.go

@@ -0,0 +1,730 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net/netip"
+	"slices"
+	"time"
+
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+	"golang.org/x/crypto/curve25519"
+)
+
+const (
+	classConstructed     = 0x20
+	classContextSpecific = 0x80
+
+	TagCertDetails   = 0 | classConstructed | classContextSpecific
+	TagCertCurve     = 1 | classContextSpecific
+	TagCertPublicKey = 2 | classContextSpecific
+	TagCertSignature = 3 | classContextSpecific
+
+	TagDetailsName           = 0 | classContextSpecific
+	TagDetailsNetworks       = 1 | classConstructed | classContextSpecific
+	TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
+	TagDetailsGroups         = 3 | classConstructed | classContextSpecific
+	TagDetailsIsCA           = 4 | classContextSpecific
+	TagDetailsNotBefore      = 5 | classContextSpecific
+	TagDetailsNotAfter       = 6 | classContextSpecific
+	TagDetailsIssuer         = 7 | classContextSpecific
+)
+
+const (
+	// MaxCertificateSize is the maximum length a valid certificate can be
+	MaxCertificateSize = 65536
+
+	// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
+	MaxNameLength = 253
+
+	// MaxNetworkLength is the maximum length a network value can be.
+	// 16 bytes for an ipv6 address + 1 byte for the prefix length
+	MaxNetworkLength = 17
+)
+
+type certificateV2 struct {
+	details detailsV2
+
+	// RawDetails contains the entire asn.1 DER encoded Details struct
+	// This is to benefit forwards compatibility in signature checking.
+	// signature(RawDetails + Curve + PublicKey) == Signature
+	rawDetails []byte
+	curve      Curve
+	publicKey  []byte
+	signature  []byte
+}
+
+type detailsV2 struct {
+	name           string
+	networks       []netip.Prefix // MUST BE SORTED
+	unsafeNetworks []netip.Prefix // MUST BE SORTED
+	groups         []string
+	isCA           bool
+	notBefore      time.Time
+	notAfter       time.Time
+	issuer         string
+}
+
+func (c *certificateV2) Version() Version {
+	return Version2
+}
+
+func (c *certificateV2) Curve() Curve {
+	return c.curve
+}
+
+func (c *certificateV2) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV2) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV2) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV2) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV2) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV2) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV2) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV2) PublicKey() []byte {
+	return c.publicKey
+}
+
+func (c *certificateV2) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV2) Fingerprint() (string, error) {
+	if len(c.rawDetails) == 0 {
+		return "", ErrMissingDetails
+	}
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV2) CheckSignature(key []byte) bool {
+	if len(c.rawDetails) == 0 {
+		return false
+	}
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+
+	switch c.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV2) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.curve {
+		return ErrPublicPrivateCurveMismatch
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return ErrInvalidPrivateKey
+			}
+
+			if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return ErrInvalidPrivateKey
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.publicKey) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.publicKey) {
+		return ErrPublicPrivateKeyMismatch
+	}
+
+	return nil
+}
+
+func (c *certificateV2) String() string {
+	mb, err := c.marshalJSON()
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+
+	b, err := json.MarshalIndent(mb, "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Skipping the curve and public key since those come across in a different part of the handshake
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) Marshal() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Add the curve only if its not the default value
+		if c.curve != Curve_CURVE25519 {
+			b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
+				b.AddBytes([]byte{byte(c.curve)})
+			})
+		}
+
+		// Add the public key if it is not empty
+		if c.publicKey != nil {
+			b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
+				b.AddBytes(c.publicKey)
+			})
+		}
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
+}
+
+func (c *certificateV2) MarshalJSON() ([]byte, error) {
+	b, err := c.marshalJSON()
+	if err != nil {
+		return nil, err
+	}
+	return json.Marshal(b)
+}
+
+func (c *certificateV2) marshalJSON() (m, error) {
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, err
+	}
+
+	return m{
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+		},
+		"version":     Version2,
+		"publicKey":   fmt.Sprintf("%x", c.publicKey),
+		"curve":       c.curve.String(),
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}, nil
+}
+
+func (c *certificateV2) Copy() Certificate {
+	nc := &certificateV2{
+		details: detailsV2{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+		},
+		curve:      c.curve,
+		publicKey:  make([]byte, len(c.publicKey)),
+		signature:  make([]byte, len(c.signature)),
+		rawDetails: make([]byte, len(c.rawDetails)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.rawDetails, c.rawDetails)
+	copy(nc.signature, c.signature)
+	copy(nc.publicKey, c.publicKey)
+
+	return nc
+}
+
+func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV2{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		isCA:           t.IsCA,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		issuer:         t.issuer,
+	}
+	c.curve = t.Curve
+	c.publicKey = t.PublicKey
+	return c.validate()
+}
+
+func (c *certificateV2) validate() error {
+	// Empty names are allowed
+
+	if len(c.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
+	}
+
+	hasV4Networks := false
+	hasV6Networks := false
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+
+		if network.Addr().Is4In6() {
+			return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
+		}
+
+		hasV4Networks = hasV4Networks || network.Addr().Is4()
+		hasV6Networks = hasV6Networks || network.Addr().Is6()
+	}
+
+	slices.SortFunc(c.details.networks, comparePrefix)
+	err := findDuplicatePrefix(c.details.networks)
+	if err != nil {
+		return err
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+
+		if !c.details.isCA {
+			if network.Addr().Is6() {
+				if !hasV6Networks {
+					return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
+				}
+			} else if network.Addr().Is4() {
+				if !hasV4Networks {
+					return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
+				}
+			}
+		}
+	}
+
+	slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
+	err = findDuplicatePrefix(c.details.unsafeNetworks)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *certificateV2) marshalForSigning() ([]byte, error) {
+	d, err := c.details.Marshal()
+	if err != nil {
+		return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
+	}
+	c.rawDetails = d
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	return b, nil
+}
+
+func (c *certificateV2) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+func (d *detailsV2) Marshal() ([]byte, error) {
+	var b cryptobyte.Builder
+	var err error
+
+	// Details are a structure
+	b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
+
+		// Add the name
+		b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
+			b.AddBytes([]byte(d.name))
+		})
+
+		// Add the networks if any exist
+		if len(d.networks) > 0 {
+			b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.networks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add the unsafe networks if any exist
+		if len(d.unsafeNetworks) > 0 {
+			b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.unsafeNetworks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add groups if any exist
+		if len(d.groups) > 0 {
+			b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
+				for _, group := range d.groups {
+					b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
+						b.AddBytes([]byte(group))
+					})
+				}
+			})
+		}
+
+		// Add IsCA only if true
+		if d.isCA {
+			b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
+				b.AddUint8(0xff)
+			})
+		}
+
+		// Add not before
+		b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
+
+		// Add not after
+		b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
+
+		// Add the issuer if present
+		if d.issuer != "" {
+			issuerBytes, innerErr := hex.DecodeString(d.issuer)
+			if innerErr != nil {
+				err = fmt.Errorf("failed to decode issuer: %w", innerErr)
+				return
+			}
+			b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
+				b.AddBytes(issuerBytes)
+			})
+		}
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return b.Bytes()
+}
+
+func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
+	l := len(b)
+	if l == 0 || l > MaxCertificateSize {
+		return nil, ErrBadFormat
+	}
+
+	input := cryptobyte.String(b)
+	// Open the envelope
+	if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the cert details, we need to preserve the tag and length
+	var rawDetails cryptobyte.String
+	if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	//Maybe grab the curve
+	var rawCurve byte
+	if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
+		return nil, ErrBadFormat
+	}
+	curve = Curve(rawCurve)
+
+	// Maybe grab the public key
+	var rawPublicKey cryptobyte.String
+	if len(publicKey) > 0 {
+		rawPublicKey = publicKey
+	} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
+		return nil, ErrBadFormat
+	}
+
+	if len(rawPublicKey) == 0 {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the signature
+	var rawSignature cryptobyte.String
+	if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Finally unmarshal the details
+	details, err := unmarshalDetails(rawDetails)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &certificateV2{
+		details:    details,
+		rawDetails: rawDetails,
+		curve:      curve,
+		publicKey:  rawPublicKey,
+		signature:  rawSignature,
+	}
+
+	err = c.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
+
+func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
+	// Open the envelope
+	if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the name
+	var name cryptobyte.String
+	if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the network addresses
+	var subString cryptobyte.String
+	var found bool
+
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var networks []netip.Prefix
+	var val cryptobyte.String
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			networks = append(networks, n)
+		}
+	}
+
+	// Read out any unsafe networks
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var unsafeNetworks []netip.Prefix
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			unsafeNetworks = append(unsafeNetworks, n)
+		}
+	}
+
+	// Read out any groups
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var groups []string
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
+				return detailsV2{}, ErrBadFormat
+			}
+			groups = append(groups, string(val))
+		}
+	}
+
+	// Read out IsCA
+	var isCa bool
+	if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read not before and not after
+	var notBefore int64
+	if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var notAfter int64
+	if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read issuer
+	var issuer cryptobyte.String
+	if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	return detailsV2{
+		name:           string(name),
+		networks:       networks,
+		unsafeNetworks: unsafeNetworks,
+		groups:         groups,
+		isCA:           isCa,
+		notBefore:      time.Unix(notBefore, 0),
+		notAfter:       time.Unix(notAfter, 0),
+		issuer:         hex.EncodeToString(issuer),
+	}, nil
+}

+ 267 - 0
cert/cert_v2_test.go

@@ -0,0 +1,267 @@
+package cert
+
+import (
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCertificateV2_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	nc.rawDetails = db
+
+	b, err := nc.Marshal()
+	require.NoError(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
+	require.NoError(t, err)
+
+	assert.Equal(t, Version2, nc.Version())
+	assert.Equal(t, Curve_CURVE25519, nc.Curve())
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+	assert.Equal(t, nc.Issuer(), nc2.Issuer())
+
+	// unmarshalling will sort networks and unsafeNetworks, we need to do the same
+	// but first make sure it fails
+	assert.NotEqual(t, nc.Networks(), nc2.Networks())
+	assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	slices.SortFunc(nc.details.networks, comparePrefix)
+	slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV2_Expired(t *testing.T) {
+	nc := certificateV2{
+		details: detailsV2{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV2_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedf1234567890abcedf")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			isCA:      false,
+			issuer:    "1234567890abcedf1234567890abcedf",
+		},
+		publicKey: pubKey,
+		signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
+	}
+
+	b, err := nc.MarshalJSON()
+	require.ErrorIs(t, err, ErrMissingDetails)
+
+	rd, err := nc.details.Marshal()
+	require.NoError(t, err)
+
+	nc.rawDetails = rd
+	b, err = nc.MarshalJSON()
+	require.NoError(t, err)
+	assert.JSONEq(
+		t,
+		"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
+		string(b),
+	)
+}
+
+func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	require.NoError(t, err)
+
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	_, caKey2, err := ed25519.GenerateKey(rand.Reader)
+	require.NoError(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	require.NoError(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	ac, ok := c.(*certificateV2)
+	require.True(t, ok)
+	ac.curve = Curve(99)
+	err = c.VerifyPrivateKey(Curve(99), priv2)
+	require.EqualError(t, err, "invalid curve: 99")
+
+	ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	require.NoError(t, err)
+
+	err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv[:16])
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv)
+	require.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	aCa, ok := ca2.(*certificateV2)
+	require.True(t, ok)
+	aCa.curve = Curve(99)
+	err = aCa.VerifyPrivateKey(Curve(99), priv2)
+	require.EqualError(t, err, "invalid curve: 99")
+
+}
+
+func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	require.NoError(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	require.NoError(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	require.Error(t, err)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	require.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	require.NoError(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	require.Error(t, err)
+}
+
+func TestCertificateV2_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV2(t *testing.T) {
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
+	require.EqualError(t, err, "bad wire format")
+}
+
+func TestCertificateV2_marshalForSigningStability(t *testing.T) {
+	before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
+	after := before.Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
+	expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
+	require.NoError(t, err)
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	assert.Equal(t, expectedRawDetails, db)
+
+	expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
+	b, err := nc.marshalForSigning()
+	require.NoError(t, err)
+	assert.Equal(t, expectedForSigning, b)
+}

+ 159 - 2
cert/crypto.go

@@ -3,14 +3,28 @@ package cert
 import (
 	"crypto/aes"
 	"crypto/cipher"
+	"crypto/ed25519"
 	"crypto/rand"
+	"encoding/pem"
 	"fmt"
 	"io"
+	"math"
 
 	"golang.org/x/crypto/argon2"
+	"google.golang.org/protobuf/proto"
 )
 
-// KDF factors
+type NebulaEncryptedData struct {
+	EncryptionMetadata NebulaEncryptionMetadata
+	Ciphertext         []byte
+}
+
+type NebulaEncryptionMetadata struct {
+	EncryptionAlgorithm string
+	Argon2Parameters    Argon2Parameters
+}
+
+// Argon2Parameters KDF factors
 type Argon2Parameters struct {
 	version     rune
 	Memory      uint32 // KiB
@@ -19,7 +33,7 @@ type Argon2Parameters struct {
 	salt        []byte
 }
 
-// Returns a new Argon2Parameters object with current version set
+// NewArgon2Parameters Returns a new Argon2Parameters object with current version set
 func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
 	return &Argon2Parameters{
 		version:     argon2.Version,
@@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) {
 
 	return blob[:nonceSize], blob[nonceSize:], nil
 }
+
+// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
+func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
+	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
+	if err != nil {
+		return nil, err
+	}
+
+	b, err = proto.Marshal(&RawNebulaEncryptedData{
+		EncryptionMetadata: &RawNebulaEncryptionMetadata{
+			EncryptionAlgorithm: "AES-256-GCM",
+			Argon2Parameters: &RawNebulaArgon2Parameters{
+				Version:     kdfParams.version,
+				Memory:      kdfParams.Memory,
+				Parallelism: uint32(kdfParams.Parallelism),
+				Iterations:  kdfParams.Iterations,
+				Salt:        kdfParams.salt,
+			},
+		},
+		Ciphertext: ciphertext,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
+	default:
+		return nil, fmt.Errorf("invalid curve: %v", curve)
+	}
+}
+
+// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
+// protobuf-generated struct.
+func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rned RawNebulaEncryptedData
+	err := proto.Unmarshal(b, &rned)
+	if err != nil {
+		return nil, err
+	}
+
+	if rned.EncryptionMetadata == nil {
+		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
+	}
+
+	if rned.EncryptionMetadata.Argon2Parameters == nil {
+		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
+	}
+
+	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
+	if err != nil {
+		return nil, err
+	}
+
+	ned := NebulaEncryptedData{
+		EncryptionMetadata: NebulaEncryptionMetadata{
+			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
+			Argon2Parameters:    *params,
+		},
+		Ciphertext: rned.Ciphertext,
+	}
+
+	return &ned, nil
+}
+
+func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
+	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
+		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
+	}
+	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
+		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
+	}
+	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
+		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
+	}
+	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
+		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
+	}
+
+	return &Argon2Parameters{
+		version:     params.Version,
+		Memory:      params.Memory,
+		Parallelism: uint8(params.Parallelism),
+		Iterations:  params.Iterations,
+		salt:        params.Salt,
+	}, nil
+
+}
+
+// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
+// the given passphrase, returning any other bytes b or an error on failure
+func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
+	var curve Curve
+
+	k, r := pem.Decode(b)
+	if k == nil {
+		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+	case EncryptedECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+	default:
+		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	}
+
+	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
+	if err != nil {
+		return curve, nil, r, err
+	}
+
+	var bytes []byte
+	switch ned.EncryptionMetadata.EncryptionAlgorithm {
+	case "AES-256-GCM":
+		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
+		if err != nil {
+			return curve, nil, r, err
+		}
+	default:
+		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
+	}
+
+	switch curve {
+	case Curve_CURVE25519:
+		if len(bytes) != ed25519.PrivateKeySize {
+			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case Curve_P256:
+		if len(bytes) != 32 {
+			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	}
+
+	return curve, bytes, r, nil
+}

+ 90 - 2
cert/crypto_test.go

@@ -4,22 +4,110 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/crypto/argon2"
 )
 
 func TestNewArgon2Parameters(t *testing.T) {
 	p := NewArgon2Parameters(64*1024, 4, 3)
-	assert.EqualValues(t, &Argon2Parameters{
+	assert.Equal(t, &Argon2Parameters{
 		version:     argon2.Version,
 		Memory:      64 * 1024,
 		Parallelism: 4,
 		Iterations:  3,
 	}, p)
 	p = NewArgon2Parameters(2*1024*1024, 2, 1)
-	assert.EqualValues(t, &Argon2Parameters{
+	assert.Equal(t, &Argon2Parameters{
 		version:     argon2.Version,
 		Memory:      2 * 1024 * 1024,
 		Parallelism: 2,
 		Iterations:  1,
 	}, p)
 }
+
+func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
+	passphrase := []byte("DO NOT USE THIS KEY")
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A key which, once decrypted, is too short
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
+k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
+GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
+rQr3bdH3Oy/WiYU=
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner (not encrypted)
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
+XgLvodMXZJuaFPssp+WwtA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+
+	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
+	require.NoError(t, err)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+
+	// Fail due to invalid banner
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to invalid passphrase
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
+	require.EqualError(t, err, "invalid passphrase or corrupt private key")
+	assert.Nil(t, k)
+	assert.Equal(t, []byte{}, rest)
+}
+
+func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
+	// Having proved that decryption works correctly above, we can test the
+	// encryption function produces a value which can be decrypted
+	passphrase := []byte("passphrase")
+	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
+	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
+	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
+	require.NoError(t, err)
+
+	// Verify the "key" can be decrypted successfully
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
+	assert.Len(t, k, 64)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, []byte{}, rest)
+	require.NoError(t, err)
+
+	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
+}

+ 41 - 6
cert/errors.go

@@ -2,13 +2,48 @@ package cert
 
 import (
 	"errors"
+	"fmt"
 )
 
 var (
-	ErrRootExpired       = errors.New("root certificate is expired")
-	ErrExpired           = errors.New("certificate is expired")
-	ErrNotCA             = errors.New("certificate is not a CA")
-	ErrNotSelfSigned     = errors.New("certificate is not self-signed")
-	ErrBlockListed       = errors.New("certificate is in the block list")
-	ErrSignatureMismatch = errors.New("certificate signature did not match")
+	ErrBadFormat                  = errors.New("bad wire format")
+	ErrRootExpired                = errors.New("root certificate is expired")
+	ErrExpired                    = errors.New("certificate is expired")
+	ErrNotCA                      = errors.New("certificate is not a CA")
+	ErrNotSelfSigned              = errors.New("certificate is not self-signed")
+	ErrBlockListed                = errors.New("certificate is in the block list")
+	ErrFingerprintMismatch        = errors.New("certificate fingerprint did not match")
+	ErrSignatureMismatch          = errors.New("certificate signature did not match")
+	ErrInvalidPublicKey           = errors.New("invalid public key")
+	ErrInvalidPrivateKey          = errors.New("invalid private key")
+	ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
+	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
+	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
+	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
+
+	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
+	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")
+	ErrInvalidPEMX25519PublicKeyBanner   = errors.New("bytes did not contain a proper X25519 public key banner")
+	ErrInvalidPEMX25519PrivateKeyBanner  = errors.New("bytes did not contain a proper X25519 private key banner")
+	ErrInvalidPEMEd25519PublicKeyBanner  = errors.New("bytes did not contain a proper Ed25519 public key banner")
+	ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
+
+	ErrNoPeerStaticKey = errors.New("no peer static key was present")
+	ErrNoPayload       = errors.New("provided payload was empty")
+
+	ErrMissingDetails  = errors.New("certificate did not contain details")
+	ErrEmptySignature  = errors.New("empty signature")
+	ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
 )
+
+type ErrInvalidCertificateProperties struct {
+	str string
+}
+
+func NewErrInvalidCertificateProperties(format string, a ...any) error {
+	return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
+}
+
+func (e *ErrInvalidCertificateProperties) Error() string {
+	return e.str
+}

+ 141 - 0
cert/helper_test.go

@@ -0,0 +1,141 @@
+package cert
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 161 - 0
cert/pem.go

@@ -0,0 +1,161 @@
+package cert
+
+import (
+	"encoding/pem"
+	"fmt"
+
+	"golang.org/x/crypto/ed25519"
+)
+
+const (
+	CertificateBanner                = "NEBULA CERTIFICATE"
+	CertificateV2Banner              = "NEBULA CERTIFICATE V2"
+	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
+	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
+	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
+	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
+	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
+
+	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
+	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
+	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
+	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+)
+
+// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
+// data or an error on failure
+func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
+	p, r := pem.Decode(b)
+	if p == nil {
+		return nil, r, ErrInvalidPEMBlock
+	}
+
+	var c Certificate
+	var err error
+
+	switch p.Type {
+	// Implementations must validate the resulting certificate contains valid information
+	case CertificateBanner:
+		c, err = unmarshalCertificateV1(p.Bytes, nil)
+	case CertificateV2Banner:
+		c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
+	default:
+		return nil, r, ErrInvalidPEMCertificateBanner
+	}
+
+	if err != nil {
+		return nil, r, err
+	}
+
+	return c, r, nil
+
+}
+
+func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PublicKeyBanner:
+		// Uncompressed
+		expectedLen = 65
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
+// consumed data or an error on failure
+func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var curve Curve
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
+	case EncryptedECDSAP256PrivateKeyBanner:
+		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
+	case Ed25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+		if len(k.Bytes) != ed25519.PrivateKeySize {
+			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case ECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+		if len(k.Bytes) != 32 {
+			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner")
+	}
+	return k.Bytes, r, curve, nil
+}

+ 293 - 0
cert/pem_test.go

@@ -0,0 +1,293 @@
+package cert
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestUnmarshalCertificateFromPEM(t *testing.T) {
+	goodCert := []byte(`
+# A good cert
+-----BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NEBULA CERTIFICATE-----
+`)
+	badBanner := []byte(`# A bad banner
+-----BEGIN NOT A NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NOT A NEBULA CERTIFICATE-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-END NEBULA CERTIFICATE----`)
+
+	certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
+
+	// Success test case
+	cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
+	assert.NotNil(t, cert)
+	assert.Equal(t, rest, append(badBanner, invalidPem...))
+	require.NoError(t, err)
+
+	// Fail due to invalid banner.
+	cert, rest, err = UnmarshalCertificateFromPEM(rest)
+	assert.Nil(t, cert)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "bytes did not contain a proper certificate banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	cert, rest, err = UnmarshalCertificateFromPEM(rest)
+	assert.Nil(t, cert)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ECDSA P256 PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-END NEBULA ED25519 PRIVATE KEY-----`)
+
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	require.NoError(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+	require.NoError(t, err)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PRIVATE KEY-----`)
+
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	require.NoError(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+	require.NoError(t, err)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "bytes did not contain a proper private key banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
+	pubKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA ED25519 PUBLIC KEY-----`)
+
+	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
+	assert.Len(t, k, 32)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	require.NoError(t, err)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	require.EqualError(t, err, "bytes did not contain a proper public key banner")
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalX25519PublicKey(t *testing.T) {
+	pubKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+	pubP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PUBLIC KEY-----`)
+
+	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
+	assert.Len(t, k, 32)
+	require.NoError(t, err)
+	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Len(t, k, 65)
+	require.NoError(t, err)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	require.EqualError(t, err, "bytes did not contain a proper public key banner")
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	require.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}

+ 167 - 0
cert/sign.go

@@ -0,0 +1,167 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"math/big"
+	"net/netip"
+	"time"
+)
+
+// TBSCertificate represents a certificate intended to be signed.
+// It is invalid to use this structure as a Certificate.
+type TBSCertificate struct {
+	Version        Version
+	Name           string
+	Networks       []netip.Prefix
+	UnsafeNetworks []netip.Prefix
+	Groups         []string
+	IsCA           bool
+	NotBefore      time.Time
+	NotAfter       time.Time
+	PublicKey      []byte
+	Curve          Curve
+	issuer         string
+}
+
+type beingSignedCertificate interface {
+	// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
+	// Implementations must validate the resulting certificate contains valid information
+	fromTBSCertificate(*TBSCertificate) error
+
+	// marshalForSigning returns the bytes that should be signed
+	marshalForSigning() ([]byte, error)
+
+	// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
+	setSignature([]byte) error
+}
+
+type SignerLambda func(certBytes []byte) ([]byte, error)
+
+// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
+// details do not violate constraints of the signing certificate.
+// If the TBSCertificate is a CA then signer must be nil.
+func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
+	switch t.Curve {
+	case Curve_CURVE25519:
+		pk := ed25519.PrivateKey(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			sig := ed25519.Sign(pk, certBytes)
+			return sig, nil
+		}
+		return t.SignWith(signer, curve, sp)
+	case Curve_P256:
+		pk := &ecdsa.PrivateKey{
+			PublicKey: ecdsa.PublicKey{
+				Curve: elliptic.P256(),
+			},
+			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+			D: new(big.Int).SetBytes(key),
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			// We need to hash first for ECDSA
+			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
+			hashed := sha256.Sum256(certBytes)
+			return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
+		}
+		return t.SignWith(signer, curve, sp)
+	default:
+		return nil, fmt.Errorf("invalid curve: %s", t.Curve)
+	}
+}
+
+// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
+// You should only use SignWith if you do not have direct access to your private key.
+func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
+	if curve != t.Curve {
+		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+
+	if signer != nil {
+		if t.IsCA {
+			return nil, fmt.Errorf("can not sign a CA certificate with another")
+		}
+
+		err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks)
+		if err != nil {
+			return nil, err
+		}
+
+		issuer, err := signer.Fingerprint()
+		if err != nil {
+			return nil, fmt.Errorf("error computing issuer: %v", err)
+		}
+		t.issuer = issuer
+	} else {
+		if !t.IsCA {
+			return nil, fmt.Errorf("self signed certificates must have IsCA set to true")
+		}
+	}
+
+	var c beingSignedCertificate
+	switch t.Version {
+	case Version1:
+		c = &certificateV1{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	case Version2:
+		c = &certificateV2{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	default:
+		return nil, fmt.Errorf("unknown cert version %d", t.Version)
+	}
+
+	certBytes, err := c.marshalForSigning()
+	if err != nil {
+		return nil, err
+	}
+
+	sig, err := sp(certBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	err = c.setSignature(sig)
+	if err != nil {
+		return nil, err
+	}
+
+	sc, ok := c.(Certificate)
+	if !ok {
+		return nil, fmt.Errorf("invalid certificate")
+	}
+
+	return sc, nil
+}
+
+func comparePrefix(a, b netip.Prefix) int {
+	addr := a.Addr().Compare(b.Addr())
+	if addr == 0 {
+		return a.Bits() - b.Bits()
+	}
+	return addr
+}
+
+// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
+func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
+	if len(sortedPrefixes) < 2 {
+		return nil
+	}
+	for i := 1; i < len(sortedPrefixes); i++ {
+		if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
+			return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
+		}
+	}
+	return nil
+}

+ 91 - 0
cert/sign_test.go

@@ -0,0 +1,91 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCertificateV1_Sign(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/24"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+	}
+
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
+	require.NoError(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	require.NoError(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	require.NoError(t, err)
+	assert.NotNil(t, uc)
+}
+
+func TestCertificateV1_SignP256(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/16"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+		Curve:     Curve_P256,
+	}
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	require.NoError(t, err)
+	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
+	rawPriv := priv.D.FillBytes(make([]byte, 32))
+
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
+	require.NoError(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	require.NoError(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	require.NoError(t, err)
+	assert.NotNil(t, uc)
+}

+ 138 - 0
cert_test/cert.go

@@ -0,0 +1,138 @@
+package cert_test
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case cert.Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &cert.TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case cert.Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &cert.TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 0 - 10
cidr/parse.go

@@ -1,10 +0,0 @@
-package cidr
-
-import "net"
-
-// Parse is a convenience function that returns only the IPNet
-// This function ignores errors since it is primarily a test helper, the result could be nil
-func Parse(s string) *net.IPNet {
-	_, c, _ := net.ParseCIDR(s)
-	return c
-}

+ 0 - 203
cidr/tree4.go

@@ -1,203 +0,0 @@
-package cidr
-
-import (
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-type Node[T any] struct {
-	left     *Node[T]
-	right    *Node[T]
-	parent   *Node[T]
-	hasValue bool
-	value    T
-}
-
-type entry[T any] struct {
-	CIDR  *net.IPNet
-	Value T
-}
-
-type Tree4[T any] struct {
-	root *Node[T]
-	list []entry[T]
-}
-
-const (
-	startbit = iputil.VpnIp(0x80000000)
-)
-
-func NewTree4[T any]() *Tree4[T] {
-	tree := new(Tree4[T])
-	tree.root = &Node[T]{}
-	tree.list = []entry[T]{}
-	return tree
-}
-
-func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
-	bit := startbit
-	node := tree.root
-	next := tree.root
-
-	ip := iputil.Ip2VpnIp(cidr.IP)
-	mask := iputil.Ip2VpnIp(cidr.Mask)
-
-	// Find our last ancestor in the tree
-	for bit&mask != 0 {
-		if ip&bit != 0 {
-			next = node.right
-		} else {
-			next = node.left
-		}
-
-		if next == nil {
-			break
-		}
-
-		bit = bit >> 1
-		node = next
-	}
-
-	// We already have this range so update the value
-	if next != nil {
-		addCIDR := cidr.String()
-		for i, v := range tree.list {
-			if addCIDR == v.CIDR.String() {
-				tree.list = append(tree.list[:i], tree.list[i+1:]...)
-				break
-			}
-		}
-
-		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-		node.value = val
-		node.hasValue = true
-		return
-	}
-
-	// Build up the rest of the tree we don't already have
-	for bit&mask != 0 {
-		next = &Node[T]{}
-		next.parent = node
-
-		if ip&bit != 0 {
-			node.right = next
-		} else {
-			node.left = next
-		}
-
-		bit >>= 1
-		node = next
-	}
-
-	// Final node marks our cidr, set the value
-	node.value = val
-	node.hasValue = true
-	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-}
-
-// Contains finds the first match, which may be the least specific
-func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			return true, node.value
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-
-	}
-
-	return false, value
-}
-
-// MostSpecificContains finds the most specific match
-func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			value = node.value
-			ok = true
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return ok, value
-}
-
-type eachFunc[T any] func(T) bool
-
-// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
-// The final return value will be true if the provided function returned true
-func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			// If the each func returns true then we can exit the loop
-			if each(node.value) {
-				return true
-			}
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return false
-}
-
-// GetCIDR returns the entry added by the most recent matching AddCIDR call
-func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	ip := iputil.Ip2VpnIp(cidr.IP)
-	mask := iputil.Ip2VpnIp(cidr.Mask)
-
-	// Find our last ancestor in the tree
-	for node != nil && bit&mask != 0 {
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit = bit >> 1
-	}
-
-	if bit&mask == 0 && node != nil {
-		value = node.value
-		ok = node.hasValue
-	}
-
-	return ok, value
-}
-
-// List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4[T]) List() []entry[T] {
-	return tree.list
-}

+ 0 - 170
cidr/tree4_test.go

@@ -1,170 +0,0 @@
-package cidr
-
-import (
-	"net"
-	"testing"
-
-	"github.com/slackhq/nebula/iputil"
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_List(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
-	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
-	tree.AddCIDR(Parse("1.0.0.0/16"), "4")
-	list := tree.List()
-	assert.Len(t, list, 2)
-	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
-	assert.Equal(t, "2", list[0].Value)
-	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
-	assert.Equal(t, "4", list[1].Value)
-}
-
-func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
-	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4a", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4b", "4.1.1.2"},
-		{true, "4c", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-}
-
-func TestTree4_GetCIDR(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
-	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IPNet  *net.IPNet
-	}{
-		{true, "1", Parse("1.0.0.0/8")},
-		{true, "2", Parse("2.1.0.0/16")},
-		{true, "3", Parse("3.1.1.0/24")},
-		{true, "4a", Parse("4.1.1.0/24")},
-		{true, "4b", Parse("4.1.1.1/32")},
-		{true, "4c", Parse("4.1.2.1/32")},
-		{true, "5", Parse("254.0.0.0/4")},
-		{false, "", Parse("2.0.0.0/8")},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.GetCIDR(tt.IPNet)
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
-
-	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-}

+ 0 - 189
cidr/tree6.go

@@ -1,189 +0,0 @@
-package cidr
-
-import (
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-const startbit6 = uint64(1 << 63)
-
-type Tree6[T any] struct {
-	root4 *Node[T]
-	root6 *Node[T]
-}
-
-func NewTree6[T any]() *Tree6[T] {
-	tree := new(Tree6[T])
-	tree.root4 = &Node[T]{}
-	tree.root6 = &Node[T]{}
-	return tree
-}
-
-func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
-	var node, next *Node[T]
-
-	cidrIP, ipv4 := isIPV4(cidr.IP)
-	if ipv4 {
-		node = tree.root4
-		next = tree.root4
-
-	} else {
-		node = tree.root6
-		next = tree.root6
-	}
-
-	for i := 0; i < len(cidrIP); i += 4 {
-		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
-		mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
-		bit := startbit
-
-		// Find our last ancestor in the tree
-		for bit&mask != 0 {
-			if ip&bit != 0 {
-				next = node.right
-			} else {
-				next = node.left
-			}
-
-			if next == nil {
-				break
-			}
-
-			bit = bit >> 1
-			node = next
-		}
-
-		// Build up the rest of the tree we don't already have
-		for bit&mask != 0 {
-			next = &Node[T]{}
-			next.parent = node
-
-			if ip&bit != 0 {
-				node.right = next
-			} else {
-				node.left = next
-			}
-
-			bit >>= 1
-			node = next
-		}
-	}
-
-	// Final node marks our cidr, set the value
-	node.value = val
-	node.hasValue = true
-}
-
-// Finds the most specific match
-func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
-	var node *Node[T]
-
-	wholeIP, ipv4 := isIPV4(ip)
-	if ipv4 {
-		node = tree.root4
-	} else {
-		node = tree.root6
-	}
-
-	for i := 0; i < len(wholeIP); i += 4 {
-		ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
-		bit := startbit
-
-		for node != nil {
-			if node.hasValue {
-				value = node.value
-				ok = true
-			}
-
-			if bit == 0 {
-				break
-			}
-
-			if ip&bit != 0 {
-				node = node.right
-			} else {
-				node = node.left
-			}
-
-			bit >>= 1
-		}
-	}
-
-	return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root4
-
-	for node != nil {
-		if node.hasValue {
-			value = node.value
-			ok = true
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
-	ip := hi
-	node := tree.root6
-
-	for i := 0; i < 2; i++ {
-		bit := startbit6
-
-		for node != nil {
-			if node.hasValue {
-				value = node.value
-				ok = true
-			}
-
-			if bit == 0 {
-				break
-			}
-
-			if ip&bit != 0 {
-				node = node.right
-			} else {
-				node = node.left
-			}
-
-			bit >>= 1
-		}
-
-		ip = lo
-	}
-
-	return ok, value
-}
-
-func isIPV4(ip net.IP) (net.IP, bool) {
-	if len(ip) == net.IPv4len {
-		return ip, true
-	}
-
-	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
-		return ip[12:16], true
-	}
-
-	return ip, false
-}
-
-func isZeros(p net.IP) bool {
-	for i := 0; i < len(p); i++ {
-		if p[i] != 0 {
-			return false
-		}
-	}
-	return true
-}

+ 0 - 98
cidr/tree6_test.go

@@ -1,98 +0,0 @@
-package cidr
-
-import (
-	"encoding/binary"
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewTree6[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4b", "4.1.1.2"},
-		{true, "4c", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{true, "6a", "1:2:0:4:1:1:1:1"},
-		{true, "6b", "1:2:0:4:5:1:1:1"},
-		{true, "6c", "1:2:0:4:5:0:0:0"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree6[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	tree.AddCIDR(Parse("::/0"), "cool6")
-	ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("::"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool6", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool6", r)
-}
-
-func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewTree6[string]()
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "6a", "1:2:0:4:1:1:1:1"},
-		{true, "6b", "1:2:0:4:5:1:1:1"},
-		{true, "6c", "1:2:0:4:5:0:0:0"},
-	}
-
-	for _, tt := range tests {
-		ip := net.ParseIP(tt.IP)
-		hi := binary.BigEndian.Uint64(ip[:8])
-		lo := binary.BigEndian.Uint64(ip[8:])
-
-		ok, r := tree.MostSpecificContainsIpV6(hi, lo)
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-}

+ 137 - 73
cmd/nebula-cert/ca.go

@@ -8,13 +8,14 @@ import (
 	"fmt"
 	"io"
 	"math"
-	"net"
+	"net/netip"
 	"os"
 	"strings"
 	"time"
 
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/ed25519"
 )
 
@@ -26,32 +27,43 @@ type caFlags struct {
 	outCertPath      *string
 	outQRPath        *string
 	groups           *string
-	ips              *string
-	subnets          *string
+	networks         *string
+	unsafeNetworks   *string
 	argonMemory      *uint
 	argonIterations  *uint
 	argonParallelism *uint
 	encryption       *bool
+	version          *uint
 
-	curve *string
+	curve  *string
+	p11url *string
+
+	// Deprecated options
+	ips     *string
+	subnets *string
 }
 
 func newCaFlags() *caFlags {
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf.set.Usage = func() {}
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
+	cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
-	cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
-	cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
+	cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
+	cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
+
+	cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
+	cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &cf
 }
 
@@ -76,17 +88,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return err
 	}
 
+	isP11 := len(*cf.p11url) > 0
+
 	if err := mustFlagString("name", cf.name); err != nil {
 		return err
 	}
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
 	if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
 		return err
 	}
 	var kdfParams *cert.Argon2Parameters
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil {
 			return err
 		}
@@ -106,44 +122,57 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		}
 	}
 
-	var ips []*net.IPNet
-	if *cf.ips != "" {
-		for _, rs := range strings.Split(*cf.ips, ",") {
+	version := cert.Version(*cf.version)
+	if version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
+	var networks []netip.Prefix
+	if *cf.networks == "" && *cf.ips != "" {
+		// Pull up deprecated -ips flag if needed
+		*cf.networks = *cf.ips
+	}
+
+	if *cf.networks != "" {
+		for _, rs := range strings.Split(*cf.networks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				ip, ipNet, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid ip definition: %s", err)
+					return newHelpErrorf("invalid -networks definition: %s", rs)
 				}
-				if ip.To4() == nil {
-					return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
-
-				ipNet.IP = ip
-				ips = append(ips, ipNet)
+				networks = append(networks, n)
 			}
 		}
 	}
 
-	var subnets []*net.IPNet
-	if *cf.subnets != "" {
-		for _, rs := range strings.Split(*cf.subnets, ",") {
+	var unsafeNetworks []netip.Prefix
+	if *cf.unsafeNetworks == "" && *cf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*cf.unsafeNetworks = *cf.subnets
+	}
+
+	if *cf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				_, s, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
-				if s.IP.To4() == nil {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
-				subnets = append(subnets, s)
+				unsafeNetworks = append(unsafeNetworks, n)
 			}
 		}
 	}
 
 	var passphrase []byte
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		for i := 0; i < 5; i++ {
 			out.Write([]byte("Enter passphrase: "))
 			passphrase, err = pr.ReadPassword()
@@ -166,74 +195,109 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 
 	var curve cert.Curve
 	var pub, rawPriv []byte
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		curve = cert.Curve_CURVE25519
-		pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+
+		p11Client, err = pkclient.FromUrl(*cf.p11url)
 		if err != nil {
-			return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
 		}
-	case "P256":
-		var key *ecdsa.PrivateKey
-		curve = cert.Curve_P256
-		key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
 		if err != nil {
-			return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
 		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			curve = cert.Curve_CURVE25519
+			pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			}
+		case "P256":
+			var key *ecdsa.PrivateKey
+			curve = cert.Curve_P256
+			key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			}
 
-		// ecdh.PrivateKey lets us get at the encoded bytes, even though
-		// we aren't using ECDH here.
-		eKey, err := key.ECDH()
-		if err != nil {
-			return fmt.Errorf("error while converting ecdsa key: %s", err)
+			// ecdh.PrivateKey lets us get at the encoded bytes, even though
+			// we aren't using ECDH here.
+			eKey, err := key.ECDH()
+			if err != nil {
+				return fmt.Errorf("error while converting ecdsa key: %s", err)
+			}
+			rawPriv = eKey.Bytes()
+			pub = eKey.PublicKey().Bytes()
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
 		}
-		rawPriv = eKey.Bytes()
-		pub = eKey.PublicKey().Bytes()
 	}
 
-	nc := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      *cf.name,
-			Groups:    groups,
-			Ips:       ips,
-			Subnets:   subnets,
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(*cf.duration),
-			PublicKey: pub,
-			IsCA:      true,
-			Curve:     curve,
-		},
+	t := &cert.TBSCertificate{
+		Version:        version,
+		Name:           *cf.name,
+		Groups:         groups,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		NotBefore:      time.Now(),
+		NotAfter:       time.Now().Add(*cf.duration),
+		PublicKey:      pub,
+		IsCA:           true,
+		Curve:          curve,
 	}
 
-	if _, err := os.Stat(*cf.outKeyPath); err == nil {
-		return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+	if !isP11 {
+		if _, err := os.Stat(*cf.outKeyPath); err == nil {
+			return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+		}
 	}
 
 	if _, err := os.Stat(*cf.outCertPath); err == nil {
 		return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
 	}
 
-	err = nc.Sign(curve, rawPriv)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
-	}
-
+	var c cert.Certificate
 	var b []byte
-	if *cf.encryption {
-		b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+
+	if isP11 {
+		c, err = t.SignWith(nil, curve, p11Client.SignASN1)
 		if err != nil {
-			return fmt.Errorf("error while encrypting out-key: %s", err)
+			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 		}
 	} else {
-		b = cert.MarshalSigningPrivateKey(curve, rawPriv)
-	}
+		c, err = t.Sign(nil, curve, rawPriv)
+		if err != nil {
+			return fmt.Errorf("error while signing: %s", err)
+		}
 
-	err = os.WriteFile(*cf.outKeyPath, b, 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+		if *cf.encryption {
+			b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+			if err != nil {
+				return fmt.Errorf("error while encrypting out-key: %s", err)
+			}
+		} else {
+			b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
+		}
+
+		err = os.WriteFile(*cf.outKeyPath, b, 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
 
-	b, err = nc.MarshalToPEM()
+	b, err = c.MarshalPEM()
 	if err != nil {
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}

+ 75 - 68
cmd/nebula-cert/ca_test.go

@@ -14,10 +14,9 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
-//TODO: test file permissions
-
 func Test_caSummary(t *testing.T) {
 	assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
 }
@@ -43,17 +42,24 @@ func Test_caHelp(t *testing.T) {
 			"  -groups string\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"  -ips string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
+			"    	Deprecated, see -networks\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the certificate authority\n"+
+			"  -networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"  -out-key string\n"+
 			"    \tOptional: path to write the private key to (default \"ca.key\")\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use (default 2)\n",
 		ob.String(),
 	)
 }
@@ -82,93 +88,94 @@ func Test_ca(t *testing.T) {
 
 	// required args
 	assertHelpError(t, ca(
-		[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
+		[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
 	), "-name is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// ipv4 only ips
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// ipv4 only subnets
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
-	args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
+	require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
-	assert.Nil(t, err)
-	os.Remove(keyF.Name())
+	require.NoError(t, err)
+	require.NoError(t, os.Remove(keyF.Name()))
 
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
+	require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
-	assert.Nil(t, err)
-	os.Remove(crtF.Name())
-	os.Remove(keyF.Name())
+	require.NoError(t, err)
+	require.NoError(t, os.Remove(crtF.Name()))
+	require.NoError(t, os.Remove(keyF.Name()))
 
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Nil(t, ca(args, ob, eb, nopw))
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.NoError(t, ca(args, ob, eb, nopw))
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
+	lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, c)
+	assert.Empty(t, b)
+	require.NoError(t, err)
 	assert.Len(t, lKey, 64)
 
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
-
-	assert.Equal(t, "test", lCrt.Details.Name)
-	assert.Len(t, lCrt.Details.Ips, 0)
-	assert.True(t, lCrt.Details.IsCA)
-	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
-	assert.Len(t, lCrt.Details.Subnets, 0)
-	assert.Len(t, lCrt.Details.PublicKey, 32)
-	assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
-	assert.Equal(t, "", lCrt.Details.Issuer)
-	assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
+	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
+	assert.Empty(t, b)
+	require.NoError(t, err)
+
+	assert.Equal(t, "test", lCrt.Name())
+	assert.Empty(t, lCrt.Networks())
+	assert.True(t, lCrt.IsCA())
+	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
+	assert.Empty(t, lCrt.UnsafeNetworks())
+	assert.Len(t, lCrt.PublicKey(), 32)
+	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
+	assert.Empty(t, lCrt.Issuer())
+	assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
 
 	// test encrypted key
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Nil(t, ca(args, ob, eb, testpw))
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.NoError(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// read encrypted key file and verify default params
 	rb, _ = os.ReadFile(keyF.Name())
 	k, _ := pem.Decode(rb)
 	ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	// we won't know salt in advance, so just check start of string
 	assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
 	assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
@@ -178,8 +185,8 @@ func Test_ca(t *testing.T) {
 	var curve cert.Curve
 	curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
 	assert.Equal(t, cert.Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-	assert.Len(t, b, 0)
+	require.NoError(t, err)
+	assert.Empty(t, b)
 	assert.Len(t, lKey, 64)
 
 	// test when reading passsword results in an error
@@ -187,45 +194,45 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Error(t, ca(args, ob, eb, errpw))
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// test when user fails to enter a password
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.Nil(t, ca(args, ob, eb, nopw))
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.NoError(t, ca(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// test that we won't overwrite existing key file
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
-	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	os.Remove(keyF.Name())
 
 }

+ 49 - 20
cmd/nebula-cert/keygen.go

@@ -6,6 +6,8 @@ import (
 	"io"
 	"os"
 
+	"github.com/slackhq/nebula/pkclient"
+
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -13,8 +15,8 @@ type keygenFlags struct {
 	set        *flag.FlagSet
 	outKeyPath *string
 	outPubPath *string
-
-	curve *string
+	curve      *string
+	p11url     *string
 }
 
 func newKeygenFlags() *keygenFlags {
@@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags {
 	cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
 	cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
 	cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
 	return &cf
 }
 
@@ -33,32 +36,58 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	isP11 := len(*cf.p11url) > 0
+
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
-	if err := mustFlagString("out-pub", cf.outPubPath); err != nil {
+	if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
 		return err
 	}
 
 	var pub, rawPriv []byte
 	var curve cert.Curve
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		pub, rawPriv = x25519Keypair()
-		curve = cert.Curve_CURVE25519
-	case "P256":
-		pub, rawPriv = p256Keypair()
-		curve = cert.Curve_P256
-	default:
-		return fmt.Errorf("invalid curve: %s", *cf.curve)
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			pub, rawPriv = x25519Keypair()
+			curve = cert.Curve_CURVE25519
+		case "P256":
+			pub, rawPriv = p256Keypair()
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
+		}
 	}
 
-	err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+	if isP11 {
+		p11Client, err := pkclient.FromUrl(*cf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key: %w", err)
+		}
+	} else {
+		err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
-
-	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
+	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
 	}
@@ -72,7 +101,7 @@ func keygenSummary() string {
 
 func keygenHelp(out io.Writer) {
 	cf := newKeygenFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
 	cf.set.SetOutput(out)
 	cf.set.PrintDefaults()
 }

+ 26 - 24
cmd/nebula-cert/keygen_test.go

@@ -7,10 +7,9 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
-//TODO: test file permissions
-
 func Test_keygenSummary(t *testing.T) {
 	assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
 }
@@ -26,7 +25,8 @@ func Test_keygenHelp(t *testing.T) {
 			"  -out-key string\n"+
 			"    \tRequired: path to write the private key to\n"+
 			"  -out-pub string\n"+
-			"    \tRequired: path to write the public key to\n",
+			"    \tRequired: path to write the public key to\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n"),
 		ob.String(),
 	)
 }
@@ -37,57 +37,59 @@ func Test_keygen(t *testing.T) {
 
 	// required args
 	assertHelpError(t, keygen([]string{"-out-pub", "nope"}, ob, eb), "-out-key is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	assertHelpError(t, keygen([]string{"-out-key", "nope"}, ob, eb), "-out-pub is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
 	args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
-	assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(keyF.Name())
 
 	// failed pub write
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
-	assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// create temp pub file
 	pubF, err := os.CreateTemp("", "test.pub")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(pubF.Name())
 
 	// test proper keygen
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
-	assert.Nil(t, keygen(args, ob, eb))
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	require.NoError(t, keygen(args, ob, eb))
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
+	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
+	assert.Empty(t, b)
+	require.NoError(t, err)
 	assert.Len(t, lKey, 32)
 
 	rb, _ = os.ReadFile(pubF.Name())
-	lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
+	lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
+	assert.Empty(t, b)
+	require.NoError(t, err)
 	assert.Len(t, lPub, 32)
 }

+ 1 - 1
cmd/nebula-cert/main.go

@@ -17,7 +17,7 @@ func (he *helpError) Error() string {
 	return he.s
 }
 
-func newHelpErrorf(s string, v ...interface{}) error {
+func newHelpErrorf(s string, v ...any) error {
 	return &helpError{s: fmt.Sprintf(s, v...)}
 }
 

+ 12 - 4
cmd/nebula-cert/main_test.go

@@ -3,15 +3,15 @@ package main
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
-//TODO: all flag parsing continueOnError will print to stderr on its own currently
-
 func Test_help(t *testing.T) {
 	expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
 		"  Global flags:\n" +
@@ -77,8 +77,16 @@ func assertHelpError(t *testing.T, err error, msg string) {
 	case *helpError:
 		// good
 	default:
-		t.Fatal("err was not a helpError")
+		t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
 	}
 
-	assert.EqualError(t, err, msg)
+	require.EqualError(t, err, msg)
+}
+
+func optionalPkcs11String(msg string) string {
+	if p11Supported() {
+		return msg
+	} else {
+		return ""
+	}
 }

+ 15 - 0
cmd/nebula-cert/p11_cgo.go

@@ -0,0 +1,15 @@
+//go:build cgo && pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return true
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key")
+}

+ 16 - 0
cmd/nebula-cert/p11_stub.go

@@ -0,0 +1,16 @@
+//go:build !cgo || !pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return false
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	var ret = ""
+	return &ret
+}

+ 14 - 9
cmd/nebula-cert/print.go

@@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("unable to read cert; %s", err)
 	}
 
-	var c *cert.NebulaCertificate
+	var c cert.Certificate
 	var qrBytes []byte
 	part := 0
 
+	var jsonCerts []cert.Certificate
+
 	for {
-		c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
+		c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
 		if err != nil {
 			return fmt.Errorf("error while unmarshaling cert: %s", err)
 		}
 
 		if *pf.json {
-			b, _ := json.Marshal(c)
-			out.Write(b)
-			out.Write([]byte("\n"))
-
+			jsonCerts = append(jsonCerts, c)
 		} else {
-			out.Write([]byte(c.String()))
-			out.Write([]byte("\n"))
+			_, _ = out.Write([]byte(c.String()))
+			_, _ = out.Write([]byte("\n"))
 		}
 
 		if *pf.outQRPath != "" {
-			b, err := c.MarshalToPEM()
+			b, err := c.MarshalPEM()
 			if err != nil {
 				return fmt.Errorf("error while marshalling cert to PEM: %s", err)
 			}
@@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		part++
 	}
 
+	if *pf.json {
+		b, _ := json.Marshal(jsonCerts)
+		_, _ = out.Write(b)
+		_, _ = out.Write([]byte("\n"))
+	}
+
 	if *pf.outQRPath != "" {
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		if err != nil {

+ 158 - 34
cmd/nebula-cert/print_test.go

@@ -2,12 +2,17 @@ package main
 
 import (
 	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
 	"os"
 	"testing"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func Test_printSummary(t *testing.T) {
@@ -38,84 +43,203 @@ func Test_printCert(t *testing.T) {
 
 	// no path
 	err := printCert([]string{}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 	assertHelpError(t, err, "-path is required")
 
 	// no cert at path
 	ob.Reset()
 	eb.Reset()
 	err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
 
 	// invalid cert at path
 	ob.Reset()
 	eb.Reset()
 	tf, err := os.CreateTemp("", "print-cert")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(tf.Name())
 
 	tf.WriteString("-----BEGIN NOPE-----")
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
 
 	// test multiple certs
 	ob.Reset()
 	eb.Reset()
 	tf.Truncate(0)
 	tf.Seek(0, 0)
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test",
-			Groups:    []string{"hi"},
-			PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-		},
-		Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-	}
+	ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
+	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
 
-	p, _ := c.MarshalToPEM()
+	p, _ := c.MarshalPEM()
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
-	assert.Nil(t, err)
+	fp, _ := c.Fingerprint()
+	pk := hex.EncodeToString(c.PublicKey())
+	sig := hex.EncodeToString(c.Signature())
+	require.NoError(t, err)
 	assert.Equal(
 		t,
-		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
+		//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		`{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+`,
 		ob.String(),
 	)
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
 
 	// test json
 	ob.Reset()
 	eb.Reset()
 	tf.Truncate(0)
 	tf.Seek(0, 0)
-	c = cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test",
-			Groups:    []string{"hi"},
-			PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-		},
-		Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-	}
-
-	p, _ = c.MarshalToPEM()
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 
 	err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
-	assert.Nil(t, err)
+	fp, _ = c.Fingerprint()
+	pk = hex.EncodeToString(c.PublicKey())
+	sig = hex.EncodeToString(c.Signature())
+	require.NoError(t, err)
 	assert.Equal(
 		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
+		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
+`,
 		ob.String(),
 	)
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, eb.String())
+}
+
+// NewTestCaCert will generate a CA cert
+func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
+	var err error
+	if pubKey == nil || privKey == nil {
+		pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	t := &cert.TBSCertificate{
+		Version:        cert.Version1,
+		Name:           name,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pubKey,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey)
+	if err != nil {
+		panic(err)
+	}
+
+	return c, privKey
+}
+
+func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
+	if before.IsZero() {
+		before = ca.NotBefore()
+	}
+
+	if after.IsZero() {
+		after = ca.NotAfter()
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	pub, rawPriv := x25519Keypair()
+	nc := &cert.TBSCertificate{
+		Version:        cert.Version1,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), signerKey)
+	if err != nil {
+		panic(err)
+	}
+
+	return c, rawPriv
 }

+ 238 - 110
cmd/nebula-cert/sign.go

@@ -3,50 +3,63 @@ package main
 import (
 	"crypto/ecdh"
 	"crypto/rand"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"strings"
 	"time"
 
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 )
 
 type signFlags struct {
-	set         *flag.FlagSet
-	caKeyPath   *string
-	caCertPath  *string
-	name        *string
-	ip          *string
-	duration    *time.Duration
-	inPubPath   *string
-	outKeyPath  *string
-	outCertPath *string
-	outQRPath   *string
-	groups      *string
-	subnets     *string
+	set            *flag.FlagSet
+	version        *uint
+	caKeyPath      *string
+	caCertPath     *string
+	name           *string
+	networks       *string
+	unsafeNetworks *string
+	duration       *time.Duration
+	inPubPath      *string
+	outKeyPath     *string
+	outCertPath    *string
+	outQRPath      *string
+	groups         *string
+
+	p11url *string
+
+	// Deprecated options
+	ip      *string
+	subnets *string
 }
 
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
-	sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
+	sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
+	sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
-	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
-	return &sf
+	sf.p11url = p11Flag(sf.set)
 
+	sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
+	sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
+	return &sf
 }
 
 func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
@@ -56,8 +69,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return err
 	}
 
-	if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
-		return err
+	isP11 := len(*sf.p11url) > 0
+
+	if !isP11 {
+		if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
+			return err
+		}
 	}
 	if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
 		return err
@@ -65,50 +82,67 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	if err := mustFlagString("name", sf.name); err != nil {
 		return err
 	}
-	if err := mustFlagString("ip", sf.ip); err != nil {
-		return err
-	}
-	if *sf.inPubPath != "" && *sf.outKeyPath != "" {
+	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 
-	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
-	if err != nil {
-		return fmt.Errorf("error while reading ca-key: %s", err)
+	var v4Networks []netip.Prefix
+	var v6Networks []netip.Prefix
+	if *sf.networks == "" && *sf.ip != "" {
+		// Pull up deprecated -ip flag if needed
+		*sf.networks = *sf.ip
+	}
+
+	if len(*sf.networks) == 0 {
+		return newHelpErrorf("-networks is required")
+	}
+
+	version := cert.Version(*sf.version)
+	if version != 0 && version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
 	}
 
 	var curve cert.Curve
 	var caKey []byte
 
-	// naively attempt to decode the private key as though it is not encrypted
-	caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
-	if err == cert.ErrPrivateKeyEncrypted {
-		// ask for a passphrase until we get one
-		var passphrase []byte
-		for i := 0; i < 5; i++ {
-			out.Write([]byte("Enter passphrase: "))
-			passphrase, err = pr.ReadPassword()
-
-			if err == ErrNoTerminal {
-				return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
-			} else if err != nil {
-				return fmt.Errorf("error reading password: %s", err)
-			}
+	if !isP11 {
+		var rawCAKey []byte
+		rawCAKey, err := os.ReadFile(*sf.caKeyPath)
 
-			if len(passphrase) > 0 {
-				break
-			}
-		}
-		if len(passphrase) == 0 {
-			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+		if err != nil {
+			return fmt.Errorf("error while reading ca-key: %s", err)
 		}
 
-		curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
-		if err != nil {
-			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+		// naively attempt to decode the private key as though it is not encrypted
+		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
+		if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
+			// ask for a passphrase until we get one
+			var passphrase []byte
+			for i := 0; i < 5; i++ {
+				out.Write([]byte("Enter passphrase: "))
+				passphrase, err = pr.ReadPassword()
+
+				if errors.Is(err, ErrNoTerminal) {
+					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+				} else if err != nil {
+					return fmt.Errorf("error reading password: %s", err)
+				}
+
+				if len(passphrase) > 0 {
+					break
+				}
+			}
+			if len(passphrase) == 0 {
+				return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+			}
+
+			curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
+			if err != nil {
+				return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+			}
+		} else if err != nil {
+			return fmt.Errorf("error while parsing ca-key: %s", err)
 		}
-	} else if err != nil {
-		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 
 	rawCACert, err := os.ReadFile(*sf.caCertPath)
@@ -116,18 +150,15 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while reading ca-crt: %s", err)
 	}
 
-	caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
+	caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert)
 	if err != nil {
 		return fmt.Errorf("error while parsing ca-crt: %s", err)
 	}
 
-	if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate does not match private key")
-	}
-
-	issuer, err := caCert.Sha256Sum()
-	if err != nil {
-		return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
+	if !isP11 {
+		if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
+			return fmt.Errorf("refusing to sign, root certificate does not match private key")
+		}
 	}
 
 	if caCert.Expired(time.Now()) {
@@ -136,82 +167,99 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 	// if no duration is given, expire one second before the root expires
 	if *sf.duration <= 0 {
-		*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
+		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
 	}
 
-	ip, ipNet, err := net.ParseCIDR(*sf.ip)
-	if err != nil {
-		return newHelpErrorf("invalid ip definition: %s", err)
-	}
-	if ip.To4() == nil {
-		return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
-	}
-	ipNet.IP = ip
+	if *sf.networks != "" {
+		for _, rs := range strings.Split(*sf.networks, ",") {
+			rs := strings.Trim(rs, " ")
+			if rs != "" {
+				n, err := netip.ParsePrefix(rs)
+				if err != nil {
+					return newHelpErrorf("invalid -networks definition: %s", rs)
+				}
 
-	groups := []string{}
-	if *sf.groups != "" {
-		for _, rg := range strings.Split(*sf.groups, ",") {
-			g := strings.TrimSpace(rg)
-			if g != "" {
-				groups = append(groups, g)
+				if n.Addr().Is4() {
+					v4Networks = append(v4Networks, n)
+				} else {
+					v6Networks = append(v6Networks, n)
+				}
 			}
 		}
 	}
 
-	subnets := []*net.IPNet{}
-	if *sf.subnets != "" {
-		for _, rs := range strings.Split(*sf.subnets, ",") {
+	var v4UnsafeNetworks []netip.Prefix
+	var v6UnsafeNetworks []netip.Prefix
+	if *sf.unsafeNetworks == "" && *sf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*sf.unsafeNetworks = *sf.subnets
+	}
+
+	if *sf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				_, s, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
-				if s.IP.To4() == nil {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+
+				if n.Addr().Is4() {
+					v4UnsafeNetworks = append(v4UnsafeNetworks, n)
+				} else {
+					v6UnsafeNetworks = append(v6UnsafeNetworks, n)
 				}
-				subnets = append(subnets, s)
+			}
+		}
+	}
+
+	var groups []string
+	if *sf.groups != "" {
+		for _, rg := range strings.Split(*sf.groups, ",") {
+			g := strings.TrimSpace(rg)
+			if g != "" {
+				groups = append(groups, g)
 			}
 		}
 	}
 
 	var pub, rawPriv []byte
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		curve = cert.Curve_P256
+		p11Client, err = pkclient.FromUrl(*sf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+	}
+
 	if *sf.inPubPath != "" {
+		var pubCurve cert.Curve
 		rawPub, err := os.ReadFile(*sf.inPubPath)
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
-		var pubCurve cert.Curve
-		pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
+
+		pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub)
 		if err != nil {
 			return fmt.Errorf("error while parsing in-pub: %s", err)
 		}
 		if pubCurve != curve {
 			return fmt.Errorf("curve of in-pub does not match ca")
 		}
+	} else if isP11 {
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
+		}
 	} else {
 		pub, rawPriv = newKeypair(curve)
 	}
 
-	nc := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      *sf.name,
-			Ips:       []*net.IPNet{ipNet},
-			Groups:    groups,
-			Subnets:   subnets,
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(*sf.duration),
-			PublicKey: pub,
-			IsCA:      false,
-			Issuer:    issuer,
-			Curve:     curve,
-		},
-	}
-
-	if err := nc.CheckRootConstrains(caCert); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err)
-	}
-
 	if *sf.outKeyPath == "" {
 		*sf.outKeyPath = *sf.name + ".key"
 	}
@@ -224,25 +272,105 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 
-	err = nc.Sign(curve, caKey)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
+	var crts []cert.Certificate
+
+	notBefore := time.Now()
+	notAfter := notBefore.Add(*sf.duration)
+
+	if version == 0 || version == cert.Version1 {
+		// Make sure we at least have an ip
+		if len(v4Networks) != 1 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
+		}
+
+		if version == cert.Version1 {
+			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
+			if len(v6Networks) > 0 {
+				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
+			}
+
+			if len(v6UnsafeNetworks) > 0 {
+				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+			}
+		}
+
+		t := &cert.TBSCertificate{
+			Version:        cert.Version1,
+			Name:           *sf.name,
+			Networks:       []netip.Prefix{v4Networks[0]},
+			Groups:         groups,
+			UnsafeNetworks: v4UnsafeNetworks,
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
+	}
+
+	if version == 0 || version == cert.Version2 {
+		t := &cert.TBSCertificate{
+			Version:        cert.Version2,
+			Name:           *sf.name,
+			Networks:       append(v4Networks, v6Networks...),
+			Groups:         groups,
+			UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
 	}
 
-	if *sf.inPubPath == "" {
+	if !isP11 && *sf.inPubPath == "" {
 		if _, err := os.Stat(*sf.outKeyPath); err == nil {
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}
 
-		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-key: %s", err)
 		}
 	}
 
-	b, err := nc.MarshalToPEM()
-	if err != nil {
-		return fmt.Errorf("error while marshalling certificate: %s", err)
+	var b []byte
+	for _, c := range crts {
+		sb, err := c.MarshalPEM()
+		if err != nil {
+			return fmt.Errorf("error while marshalling certificate: %s", err)
+		}
+		b = append(b, sb...)
 	}
 
 	err = os.WriteFile(*sf.outCertPath, b, 0600)

+ 109 - 109
cmd/nebula-cert/sign_test.go

@@ -13,11 +13,10 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/crypto/ed25519"
 )
 
-//TODO: test file permissions
-
 func Test_signSummary(t *testing.T) {
 	assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
 }
@@ -39,17 +38,24 @@ func Test_signHelp(t *testing.T) {
 			"  -in-pub string\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"  -ip string\n"+
-			"    \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
+			"    \tDeprecated, see -networks\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
+			"  -networks string\n"+
+			"    \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"  -out-key string\n"+
 			"    \tOptional (if in-pub not set): path to write the private key to\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
 		ob.String(),
 	)
 }
@@ -76,20 +82,20 @@ func Test_signCert(t *testing.T) {
 
 	// required args
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
 	), "-name is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
-	), "-ip is required")
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-networks is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// cannot set -in-pub and -out-key
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
 	), "cannot set both -in-pub and -out-key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -97,18 +103,18 @@ func Test_signCert(t *testing.T) {
 	// failed to read key
 	ob.Reset()
 	eb.Reset()
-	args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
+	args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 	// failed to unmarshal key
 	ob.Reset()
 	eb.Reset()
 	caKeyF, err := os.CreateTemp("", "sign-cert.key")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caKeyF.Name())
 
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -116,11 +122,11 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
-	caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv))
+	caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
 
 	// failed to read cert
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -128,30 +134,22 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	caCrtF, err := os.CreateTemp("", "sign-cert.crt")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caCrtF.Name())
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// write a proper ca cert for later
-	ca := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(time.Minute * 200),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	b, _ := ca.MarshalToPEM()
+	ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
+	b, _ := ca.MarshalPEM()
 	caCrtF.Write(b)
 
 	// failed to read pub
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -159,11 +157,11 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	inPubF, err := os.CreateTemp("", "in.pub")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(inPubF.Name())
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -171,116 +169,124 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	inPub, _ := x25519Keypair()
-	inPubF.Write(cert.MarshalX25519PublicKey(inPub))
+	inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub))
 
 	// bad ip cidr
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// bad subnet cidr
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// mismatched ca key
 	_, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
 	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caKeyF2.Name())
-	caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
+	caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	os.Remove(keyF.Name())
 
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	os.Remove(keyF.Name())
 
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	os.Remove(crtF.Name())
 
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb, nopw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.NoError(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
+	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
+	assert.Empty(t, b)
+	require.NoError(t, err)
 	assert.Len(t, lKey, 32)
 
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
-
-	assert.Equal(t, "test", lCrt.Details.Name)
-	assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String())
-	assert.Len(t, lCrt.Details.Ips, 1)
-	assert.False(t, lCrt.Details.IsCA)
-	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
-	assert.Len(t, lCrt.Details.Subnets, 3)
-	assert.Len(t, lCrt.Details.PublicKey, 32)
-	assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
+	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
+	assert.Empty(t, b)
+	require.NoError(t, err)
+
+	assert.Equal(t, "test", lCrt.Name())
+	assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
+	assert.Len(t, lCrt.Networks(), 1)
+	assert.False(t, lCrt.IsCA())
+	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
+	assert.Len(t, lCrt.UnsafeNetworks(), 3)
+	assert.Len(t, lCrt.PublicKey(), 32)
+	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
 
 	sns := []string{}
-	for _, sn := range lCrt.Details.Subnets {
+	for _, sn := range lCrt.UnsafeNetworks() {
 		sns = append(sns, sn.String())
 	}
 	assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
 
-	issuer, _ := ca.Sha256Sum()
-	assert.Equal(t, issuer, lCrt.Details.Issuer)
+	issuer, _ := ca.Fingerprint()
+	assert.Equal(t, issuer, lCrt.Issuer())
 
 	assert.True(t, lCrt.CheckSignature(caPub))
 
@@ -289,53 +295,55 @@ func Test_signCert(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
-	assert.Nil(t, signCert(args, ob, eb, nopw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
+	require.NoError(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// read cert file and check pub key matches in-pub
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
-	assert.Len(t, b, 0)
-	assert.Nil(t, err)
-	assert.Equal(t, lCrt.Details.PublicKey, inPub)
+	lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
+	assert.Empty(t, b)
+	require.NoError(t, err)
+	assert.Equal(t, lCrt.PublicKey(), inPub)
 
 	// test refuse to sign cert with duration beyond root
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb, nopw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.NoError(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb, nopw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.NoError(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -348,11 +356,11 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	caKeyF, err = os.CreateTemp("", "sign-cert.key")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caKeyF.Name())
 
 	caCrtF, err = os.CreateTemp("", "sign-cert.crt")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caCrtF.Name())
 
 	// generate the encrypted key
@@ -361,21 +369,13 @@ func Test_signCert(t *testing.T) {
 	b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
 	caKeyF.Write(b)
 
-	ca = cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(time.Minute * 200),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	b, _ = ca.MarshalToPEM()
+	ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
+	b, _ = ca.MarshalPEM()
 	caCrtF.Write(b)
 
 	// test with the proper password
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Nil(t, signCert(args, ob, eb, testpw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.NoError(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 
@@ -384,8 +384,8 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	testpw.password = []byte("invalid password")
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Error(t, signCert(args, ob, eb, testpw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.Error(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 
@@ -393,8 +393,8 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Error(t, signCert(args, ob, eb, nopw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.Error(t, signCert(args, ob, eb, nopw))
 	// normally the user hitting enter on the prompt would add newlines between these
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
@@ -403,8 +403,8 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.Error(t, signCert(args, ob, eb, errpw))
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	require.Error(t, signCert(args, ob, eb, errpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 }

+ 26 - 15
cmd/nebula-cert/verify.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
-		return fmt.Errorf("error while reading ca: %s", err)
+		return fmt.Errorf("error while reading ca: %w", err)
 	}
 
 	caPool := cert.NewCAPool()
 	for {
-		rawCACert, err = caPool.AddCACertificate(rawCACert)
+		rawCACert, err = caPool.AddCAFromPEM(rawCACert)
 		if err != nil {
-			return fmt.Errorf("error while adding ca cert to pool: %s", err)
+			return fmt.Errorf("error while adding ca cert to pool: %w", err)
 		}
 
 		if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
-		return fmt.Errorf("unable to read crt; %s", err)
+		return fmt.Errorf("unable to read crt: %w", err)
 	}
-
-	c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
-	if err != nil {
-		return fmt.Errorf("error while parsing crt: %s", err)
-	}
-
-	good, err := c.Verify(time.Now(), caPool)
-	if !good {
-		return err
+	var errs []error
+	for {
+		if len(rawCert) == 0 {
+			break
+		}
+		c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
+		if err != nil {
+			return fmt.Errorf("error while parsing crt: %w", err)
+		}
+		rawCert = extra
+		_, err = caPool.VerifyCertificate(time.Now(), c)
+		if err != nil {
+			switch {
+			case errors.Is(err, cert.ErrCaNotFound):
+				errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
+			default:
+				errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
+			}
+		}
 	}
 
-	return nil
+	return errors.Join(errs...)
 }
 
 func verifySummary() string {
@@ -80,7 +91,7 @@ func verifySummary() string {
 
 func verifyHelp(out io.Writer) {
 	vf := newVerifyFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
 	vf.set.SetOutput(out)
 	vf.set.PrintDefaults()
 }

+ 35 - 52
cmd/nebula-cert/verify_test.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"golang.org/x/crypto/ed25519"
 )
 
@@ -37,105 +38,87 @@ func Test_verify(t *testing.T) {
 
 	// required args
 	assertHelpError(t, verify([]string{"-ca", "derp"}, ob, eb), "-crt is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	assertHelpError(t, verify([]string{"-crt", "derp"}, ob, eb), "-ca is required")
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
 
 	// no ca at path
 	ob.Reset()
 	eb.Reset()
 	err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
 
 	// invalid ca at path
 	ob.Reset()
 	eb.Reset()
 	caFile, err := os.CreateTemp("", "verify-ca")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(caFile.Name())
 
 	caFile.WriteString("-----BEGIN NOPE-----")
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
 
 	// make a ca for later
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
-	ca := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test-ca",
-			NotBefore: time.Now().Add(time.Hour * -1),
-			NotAfter:  time.Now().Add(time.Hour * 2),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	ca.Sign(cert.Curve_CURVE25519, caPriv)
-	b, _ := ca.MarshalToPEM()
+	ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
+	b, _ := ca.MarshalPEM()
 	caFile.Truncate(0)
 	caFile.Seek(0, 0)
 	caFile.Write(b)
 
 	// no crt at path
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 	// invalid crt at path
 	ob.Reset()
 	eb.Reset()
 	certFile, err := os.CreateTemp("", "verify-cert")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	defer os.Remove(certFile.Name())
 
 	certFile.WriteString("-----BEGIN NOPE-----")
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
 
 	// unverifiable cert at path
-	_, badPriv, _ := ed25519.GenerateKey(rand.Reader)
-	certPub, _ := x25519Keypair()
-	signer, _ := ca.Sha256Sum()
-	crt := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test-cert",
-			NotBefore: time.Now().Add(time.Hour * -1),
-			NotAfter:  time.Now().Add(time.Hour),
-			PublicKey: certPub,
-			IsCA:      false,
-			Issuer:    signer,
-		},
+	crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
+	// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
+	pub := crt.PublicKey()
+	for i, _ := range pub {
+		pub[i] = 0
 	}
-
-	crt.Sign(cert.Curve_CURVE25519, badPriv)
-	b, _ = crt.MarshalToPEM()
+	b, _ = crt.MarshalPEM()
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
 	certFile.Write(b)
 
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "certificate signature did not match")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.ErrorIs(t, err, cert.ErrSignatureMismatch)
 
 	// verified cert at path
-	crt.Sign(cert.Curve_CURVE25519, caPriv)
-	b, _ = crt.MarshalToPEM()
+	crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
+	b, _ = crt.MarshalPEM()
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
 	certFile.Write(b)
 
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
-	assert.Equal(t, "", ob.String())
-	assert.Equal(t, "", eb.String())
-	assert.Nil(t, err)
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+	require.NoError(t, err)
 }

+ 33 - 17
config/config.go

@@ -17,14 +17,14 @@ import (
 
 	"dario.cat/mergo"
 	"github.com/sirupsen/logrus"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 type C struct {
 	path        string
 	files       []string
-	Settings    map[interface{}]interface{}
-	oldSettings map[interface{}]interface{}
+	Settings    map[string]any
+	oldSettings map[string]any
 	callbacks   []func(*C)
 	l           *logrus.Logger
 	reloadLock  sync.Mutex
@@ -32,7 +32,7 @@ type C struct {
 
 func NewC(l *logrus.Logger) *C {
 	return &C{
-		Settings: make(map[interface{}]interface{}),
+		Settings: make(map[string]any),
 		l:        l,
 	}
 }
@@ -92,8 +92,8 @@ func (c *C) HasChanged(k string) bool {
 	}
 
 	var (
-		nv interface{}
-		ov interface{}
+		nv any
+		ov any
 	)
 
 	if k == "" {
@@ -147,7 +147,7 @@ func (c *C) ReloadConfig() {
 	c.reloadLock.Lock()
 	defer c.reloadLock.Unlock()
 
-	c.oldSettings = make(map[interface{}]interface{})
+	c.oldSettings = make(map[string]any)
 	for k, v := range c.Settings {
 		c.oldSettings[k] = v
 	}
@@ -167,7 +167,7 @@ func (c *C) ReloadConfigString(raw string) error {
 	c.reloadLock.Lock()
 	defer c.reloadLock.Unlock()
 
-	c.oldSettings = make(map[interface{}]interface{})
+	c.oldSettings = make(map[string]any)
 	for k, v := range c.Settings {
 		c.oldSettings[k] = v
 	}
@@ -201,7 +201,7 @@ func (c *C) GetStringSlice(k string, d []string) []string {
 		return d
 	}
 
-	rv, ok := r.([]interface{})
+	rv, ok := r.([]any)
 	if !ok {
 		return d
 	}
@@ -215,13 +215,13 @@ func (c *C) GetStringSlice(k string, d []string) []string {
 }
 
 // GetMap will get the map for k or return the default d if not found or invalid
-func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
+func (c *C) GetMap(k string, d map[string]any) map[string]any {
 	r := c.Get(k)
 	if r == nil {
 		return d
 	}
 
-	v, ok := r.(map[interface{}]interface{})
+	v, ok := r.(map[string]any)
 	if !ok {
 		return d
 	}
@@ -266,6 +266,22 @@ func (c *C) GetBool(k string, d bool) bool {
 	return v
 }
 
+func AsBool(v any) (value bool, ok bool) {
+	switch x := v.(type) {
+	case bool:
+		return x, true
+	case string:
+		switch x {
+		case "y", "yes":
+			return true, true
+		case "n", "no":
+			return false, true
+		}
+	}
+
+	return false, false
+}
+
 // GetDuration will get the duration for k or return the default d if not found or invalid
 func (c *C) GetDuration(k string, d time.Duration) time.Duration {
 	r := c.GetString(k, "")
@@ -276,7 +292,7 @@ func (c *C) GetDuration(k string, d time.Duration) time.Duration {
 	return v
 }
 
-func (c *C) Get(k string) interface{} {
+func (c *C) Get(k string) any {
 	return c.get(k, c.Settings)
 }
 
@@ -284,10 +300,10 @@ func (c *C) IsSet(k string) bool {
 	return c.get(k, c.Settings) != nil
 }
 
-func (c *C) get(k string, v interface{}) interface{} {
+func (c *C) get(k string, v any) any {
 	parts := strings.Split(k, ".")
 	for _, p := range parts {
-		m, ok := v.(map[interface{}]interface{})
+		m, ok := v.(map[string]any)
 		if !ok {
 			return nil
 		}
@@ -346,7 +362,7 @@ func (c *C) addFile(path string, direct bool) error {
 }
 
 func (c *C) parseRaw(b []byte) error {
-	var m map[interface{}]interface{}
+	var m map[string]any
 
 	err := yaml.Unmarshal(b, &m)
 	if err != nil {
@@ -358,7 +374,7 @@ func (c *C) parseRaw(b []byte) error {
 }
 
 func (c *C) parse() error {
-	var m map[interface{}]interface{}
+	var m map[string]any
 
 	for _, path := range c.files {
 		b, err := os.ReadFile(path)
@@ -366,7 +382,7 @@ func (c *C) parse() error {
 			return err
 		}
 
-		var nm map[interface{}]interface{}
+		var nm map[string]any
 		err = yaml.Unmarshal(b, &nm)
 		if err != nil {
 			return err

+ 31 - 34
config/config_test.go

@@ -10,7 +10,7 @@ import (
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 func TestConfig_Load(t *testing.T) {
@@ -19,40 +19,37 @@ func TestConfig_Load(t *testing.T) {
 	// invalid yaml
 	c := NewC(l)
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
-	assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
+	require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n  line 1: cannot unmarshal !!str `invalid...` into map[string]interface {}")
 
 	// simple multi config merge
 	c = NewC(l)
 	os.RemoveAll(dir)
 	os.Mkdir(dir, 0755)
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
 
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 	os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n  inner: override\nnew: hi"), 0644)
-	assert.Nil(t, c.Load(dir))
-	expected := map[interface{}]interface{}{
-		"outer": map[interface{}]interface{}{
+	require.NoError(t, c.Load(dir))
+	expected := map[string]any{
+		"outer": map[string]any{
 			"inner": "override",
 		},
 		"new": "hi",
 	}
 	assert.Equal(t, expected, c.Settings)
-
-	//TODO: test symlinked file
-	//TODO: test symlinked directory
 }
 
 func TestConfig_Get(t *testing.T) {
 	l := test.NewLogger()
 	// test simple type
 	c := NewC(l)
-	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
+	c.Settings["firewall"] = map[string]any{"outbound": "hi"}
 	assert.Equal(t, "hi", c.Get("firewall.outbound"))
 
 	// test complex type
-	inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
-	c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
+	inner := []map[string]any{{"port": "1", "code": "2"}}
+	c.Settings["firewall"] = map[string]any{"outbound": inner}
 	assert.EqualValues(t, inner, c.Get("firewall.outbound"))
 
 	// test missing
@@ -62,7 +59,7 @@ func TestConfig_Get(t *testing.T) {
 func TestConfig_GetStringSlice(t *testing.T) {
 	l := test.NewLogger()
 	c := NewC(l)
-	c.Settings["slice"] = []interface{}{"one", "two"}
+	c.Settings["slice"] = []any{"one", "two"}
 	assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
 }
 
@@ -70,28 +67,28 @@ func TestConfig_GetBool(t *testing.T) {
 	l := test.NewLogger()
 	c := NewC(l)
 	c.Settings["bool"] = true
-	assert.Equal(t, true, c.GetBool("bool", false))
+	assert.True(t, c.GetBool("bool", false))
 
 	c.Settings["bool"] = "true"
-	assert.Equal(t, true, c.GetBool("bool", false))
+	assert.True(t, c.GetBool("bool", false))
 
 	c.Settings["bool"] = false
-	assert.Equal(t, false, c.GetBool("bool", true))
+	assert.False(t, c.GetBool("bool", true))
 
 	c.Settings["bool"] = "false"
-	assert.Equal(t, false, c.GetBool("bool", true))
+	assert.False(t, c.GetBool("bool", true))
 
 	c.Settings["bool"] = "Y"
-	assert.Equal(t, true, c.GetBool("bool", false))
+	assert.True(t, c.GetBool("bool", false))
 
 	c.Settings["bool"] = "yEs"
-	assert.Equal(t, true, c.GetBool("bool", false))
+	assert.True(t, c.GetBool("bool", false))
 
 	c.Settings["bool"] = "N"
-	assert.Equal(t, false, c.GetBool("bool", true))
+	assert.False(t, c.GetBool("bool", true))
 
 	c.Settings["bool"] = "nO"
-	assert.Equal(t, false, c.GetBool("bool", true))
+	assert.False(t, c.GetBool("bool", true))
 }
 
 func TestConfig_HasChanged(t *testing.T) {
@@ -104,14 +101,14 @@ func TestConfig_HasChanged(t *testing.T) {
 	// Test key change
 	c = NewC(l)
 	c.Settings["test"] = "hi"
-	c.oldSettings = map[interface{}]interface{}{"test": "no"}
+	c.oldSettings = map[string]any{"test": "no"}
 	assert.True(t, c.HasChanged("test"))
 	assert.True(t, c.HasChanged(""))
 
 	// No key change
 	c = NewC(l)
 	c.Settings["test"] = "hi"
-	c.oldSettings = map[interface{}]interface{}{"test": "hi"}
+	c.oldSettings = map[string]any{"test": "hi"}
 	assert.False(t, c.HasChanged("test"))
 	assert.False(t, c.HasChanged(""))
 }
@@ -120,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
 	l := test.NewLogger()
 	done := make(chan bool, 1)
 	dir, err := os.MkdirTemp("", "config-test")
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n  inner: hi"), 0644)
 
 	c := NewC(l)
-	assert.Nil(t, c.Load(dir))
+	require.NoError(t, c.Load(dir))
 
 	assert.False(t, c.HasChanged("outer.inner"))
 	assert.False(t, c.HasChanged("outer"))
@@ -187,11 +184,11 @@ firewall:
 `),
 	}
 
-	var m map[any]any
+	var m map[string]any
 
 	// merge the same way config.parse() merges
 	for _, b := range configs {
-		var nm map[any]any
+		var nm map[string]any
 		err := yaml.Unmarshal(b, &nm)
 		require.NoError(t, err)
 
@@ -208,15 +205,15 @@ firewall:
 	t.Logf("Merged Config as YAML:\n%s", mYaml)
 
 	// If a bug is present, some items might be replaced instead of merged like we expect
-	expected := map[any]any{
-		"firewall": map[any]any{
+	expected := map[string]any{
+		"firewall": map[string]any{
 			"inbound": []any{
-				map[any]any{"host": "any", "port": "any", "proto": "icmp"},
-				map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
-				map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
+				map[string]any{"host": "any", "port": "any", "proto": "icmp"},
+				map[string]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
+				map[string]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
 			"outbound": []any{
-				map[any]any{"host": "any", "port": "any", "proto": "any"}}},
-		"listen": map[any]any{
+				map[string]any{"host": "any", "port": "any", "proto": "any"}}},
+		"listen": map[string]any{
 			"host": "0.0.0.0",
 			"port": 4242,
 		},

+ 67 - 38
connection_manager.go

@@ -3,14 +3,14 @@ package nebula
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
+	"net/netip"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 type trafficDecision int
@@ -182,7 +182,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 	case deleteTunnel:
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 
 	case closeTunnel:
@@ -220,11 +220,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 	for _, r := range relayFor {
-		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
 
 		var index uint32
-		var relayFrom iputil.VpnIp
-		var relayTo iputil.VpnIp
+		var relayFrom netip.Addr
+		var relayTo netip.Addr
 		switch {
 		case ok && existing.State == Established:
 			// This relay already exists in newhostinfo, then do nothing.
@@ -234,11 +234,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnIp
-				relayTo = existing.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = existing.PeerAddr
 			case ForwardingType:
-				relayFrom = existing.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = existing.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
 			}
@@ -252,18 +252,18 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			n.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnIp
-				relayTo = r.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = r.PeerAddr
 			case ForwardingType:
-				relayFrom = r.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = r.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
 			}
@@ -273,20 +273,43 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         uint32(relayFrom),
-			RelayToIp:           uint32(relayTo),
 		}
+
+		switch newhostinfo.GetCert().Certificate.Version() {
+		case cert.Version1:
+			if !relayFrom.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				continue
+			}
+
+			if !relayTo.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				continue
+			}
+
+			b := relayFrom.As4()
+			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = relayTo.As4()
+			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		case cert.Version2:
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
+			req.RelayToAddr = netAddrToProtoAddr(relayTo)
+		default:
+			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			continue
+		}
+
 		msg, err := req.Marshal()
 		if err != nil {
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
-				"relayTo":             iputil.VpnIp(req.RelayToIp),
+				"relayFrom":           req.RelayFromAddr,
+				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
-				"vpnIp":               newhostinfo.vpnIp}).
+				"vpnAddrs":            newhostinfo.vpnAddrs}).
 				Info("send CreateRelayRequest")
 		}
 	}
@@ -308,7 +331,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		return closeTunnel, hostinfo, nil
 	}
 
-	primary := n.hostMap.Hosts[hostinfo.vpnIp]
+	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
@@ -402,21 +425,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 
-	if current.vpnIp < n.intf.myVpnIp {
-		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
-		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
-		// The remotes vpn ip is lower than mine. I will not flip.
+	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
+	// vpn addr is static across all tunnels for this host pair so lets
+	// use that to determine if we should consider swapping.
+	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
+		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}
 
-	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
+	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
+	// settle down.
+	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 	n.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnIp] == primary {
+	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
 		n.hostMap.unlockedMakePrimary(current)
 	}
 	n.hostMap.Unlock()
@@ -431,8 +457,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
-	if valid {
+	caPool := n.intf.pki.GetCAPool()
+	err := caPool.VerifyCachedCertificate(now, remoteCert)
+	if err == nil {
 		return false
 	}
 
@@ -441,9 +468,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	fingerprint, _ := remoteCert.Sha256Sum()
 	hostinfo.logger(n.l).WithError(err).
-		WithField("fingerprint", fingerprint).
+		WithField("fingerprint", remoteCert.Fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
 	return true
@@ -456,26 +482,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 	}
 
 	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
+		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
 			n.metricsTxPunchy.Inc(1)
 			n.intf.outside.WriteTo([]byte{1}, addr)
 		})
 
-	} else if hostinfo.remote != nil {
+	} else if hostinfo.remote.IsValid() {
 		n.metricsTxPunchy.Inc(1)
 		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
+	cs := n.intf.pki.getCertState()
+	curCrt := hostinfo.ConnectionState.myCert
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
 
-	n.l.WithField("vpnIp", hostinfo.vpnIp).
+	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }

+ 178 - 76
connection_manager_test.go

@@ -4,29 +4,27 @@ import (
 	"context"
 	"crypto/ed25519"
 	"crypto/rand"
-	"net"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
-var vpnIp iputil.VpnIp
-
 func newTestLighthouse() *LightHouse {
 	lh := &LightHouse{
 		l:         test.NewLogger(),
-		addrMap:   map[iputil.VpnIp]*RemoteList{},
-		queryChan: make(chan iputil.VpnIp, 10),
+		addrMap:   map[netip.Addr]*RemoteList{},
+		queryChan: make(chan netip.Addr, 10),
 	}
-	lighthouses := map[iputil.VpnIp]struct{}{}
-	staticList := map[iputil.VpnIp]struct{}{}
+	lighthouses := map[netip.Addr]struct{}{}
+	staticList := map[netip.Addr]struct{}{}
 
 	lh.lighthouses.Store(&lighthouses)
 	lh.staticList.Store(&staticList)
@@ -37,20 +35,19 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -77,12 +74,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &cert.NebulaCertificate{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -91,7 +88,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 
@@ -108,31 +105,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -159,12 +156,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &cert.NebulaCertificate{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -172,8 +169,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@@ -189,7 +186,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
@@ -198,7 +195,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
 
 // Check if we can disconnect the peer.
@@ -207,54 +204,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	now := time.Now()
 	l := test.NewLogger()
-	ipNet := net.IPNet{
-		IP:   net.IPv4(172, 1, 1, 2),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
-	hostMap := newHostMap(l, vpncidr)
+
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	// Generate keys for CA and peer's cert.
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
-	caCert := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: now,
-			NotAfter:  now.Add(1 * time.Hour),
-			IsCA:      true,
-			PublicKey: pubCA,
-		},
+	tbs := &cert.TBSCertificate{
+		Version:   1,
+		Name:      "ca",
+		IsCA:      true,
+		NotBefore: now,
+		NotAfter:  now.Add(1 * time.Hour),
+		PublicKey: pubCA,
 	}
 
-	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
-	ncp := &cert.NebulaCAPool{
-		CAs: cert.NewCAPool().CAs,
-	}
-	ncp.CAs["ca"] = &caCert
+	caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
+	require.NoError(t, err)
+	ncp := cert.NewCAPool()
+	require.NoError(t, ncp.AddCA(caCert))
 
 	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
-	peerCert := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "host",
-			Ips:       []*net.IPNet{&ipNet},
-			Subnets:   []*net.IPNet{},
-			NotBefore: now,
-			NotAfter:  now.Add(60 * time.Second),
-			PublicKey: pubCrt,
-			IsCA:      false,
-			Issuer:    "ca",
-		},
+	tbs = &cert.TBSCertificate{
+		Version:   1,
+		Name:      "host",
+		Networks:  []netip.Prefix{vpncidr},
+		NotBefore: now,
+		NotAfter:  now.Add(60 * time.Second),
+		PublicKey: pubCrt,
 	}
-	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
+	peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
+	require.NoError(t, err)
+
+	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -280,10 +271,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.connectionManager = nc
 
 	hostinfo := &HostInfo{
-		vpnIp: vpnIp,
+		vpnAddrs: []netip.Addr{vpnIp},
 		ConnectionState: &ConnectionState{
-			myCert:   &cert.NebulaCertificate{},
-			peerCert: &peerCert,
+			myCert:   &dummyCert{},
+			peerCert: cachedPeerCert,
 			H:        &noise.HandshakeState{},
 		},
 	}
@@ -303,3 +294,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
 	assert.True(t, invalid)
 }
+
+type dummyCert struct {
+	version        cert.Version
+	curve          cert.Curve
+	groups         []string
+	isCa           bool
+	issuer         string
+	name           string
+	networks       []netip.Prefix
+	notAfter       time.Time
+	notBefore      time.Time
+	publicKey      []byte
+	signature      []byte
+	unsafeNetworks []netip.Prefix
+}
+
+func (d *dummyCert) Version() cert.Version {
+	return d.version
+}
+
+func (d *dummyCert) Curve() cert.Curve {
+	return d.curve
+}
+
+func (d *dummyCert) Groups() []string {
+	return d.groups
+}
+
+func (d *dummyCert) IsCA() bool {
+	return d.isCa
+}
+
+func (d *dummyCert) Issuer() string {
+	return d.issuer
+}
+
+func (d *dummyCert) Name() string {
+	return d.name
+}
+
+func (d *dummyCert) Networks() []netip.Prefix {
+	return d.networks
+}
+
+func (d *dummyCert) NotAfter() time.Time {
+	return d.notAfter
+}
+
+func (d *dummyCert) NotBefore() time.Time {
+	return d.notBefore
+}
+
+func (d *dummyCert) PublicKey() []byte {
+	return d.publicKey
+}
+
+func (d *dummyCert) Signature() []byte {
+	return d.signature
+}
+
+func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
+	return d.unsafeNetworks
+}
+
+func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
+	return nil
+}
+
+func (d *dummyCert) CheckSignature(key []byte) bool {
+	return true
+}
+
+func (d *dummyCert) Expired(t time.Time) bool {
+	return false
+}
+
+func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
+	return nil
+}
+
+func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
+	return nil
+}
+
+func (d *dummyCert) String() string {
+	return ""
+}
+
+func (d *dummyCert) Marshal() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) MarshalPEM() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Fingerprint() (string, error) {
+	return "", nil
+}
+
+func (d *dummyCert) MarshalJSON() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Copy() cert.Certificate {
+	return d
+}

+ 35 - 23
connection_state.go

@@ -3,6 +3,7 @@ package nebula
 import (
 	"crypto/rand"
 	"encoding/json"
+	"fmt"
 	"sync/atomic"
 
 	"github.com/flynn/noise"
@@ -17,50 +18,54 @@ type ConnectionState struct {
 	eKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	H              *noise.HandshakeState
-	myCert         *cert.NebulaCertificate
-	peerCert       *cert.NebulaCertificate
+	myCert         cert.Certificate
+	peerCert       *cert.CachedCertificate
 	initiator      bool
 	messageCounter atomic.Uint64
 	window         *Bits
 	writeLock      syncMutex
 }
 
-func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
 	var dhFunc noise.DHFunc
-	switch certState.Certificate.Details.Curve {
+	switch crt.Curve() {
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
-		dhFunc = noiseutil.DHP256
+		if cs.pkcs11Backed {
+			dhFunc = noiseutil.DHP256PKCS11
+		} else {
+			dhFunc = noiseutil.DHP256
+		}
 	default:
-		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
-		return nil
+		return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
 	}
 
-	var cs noise.CipherSuite
-	if cipher == "chachapoly" {
-		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	var ncs noise.CipherSuite
+	if cs.cipher == "chachapoly" {
+		ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	} else {
-		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
+		ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 
-	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
+	static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
 
 	b := NewBits(ReplayWindow)
-	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
+	// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
 	b.Update(l, 0)
 
 	hs, err := noise.NewHandshakeState(noise.Config{
-		CipherSuite:           cs,
-		Random:                rand.Reader,
-		Pattern:               pattern,
-		Initiator:             initiator,
-		StaticKeypair:         static,
-		PresharedKey:          psk,
-		PresharedKeyPlacement: pskStage,
+		CipherSuite:   ncs,
+		Random:        rand.Reader,
+		Pattern:       pattern,
+		Initiator:     initiator,
+		StaticKeypair: static,
+		//NOTE: These should come from CertState (pki.go) when we finally implement it
+		PresharedKey:          []byte{},
+		PresharedKeyPlacement: 0,
 	})
 	if err != nil {
-		return nil
+		return nil, fmt.Errorf("NewConnectionState: %s", err)
 	}
 
 	// The queue and ready params prevent a counter race that would happen when
@@ -69,11 +74,14 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		initiator: initiator,
 		window:    b,
-		myCert:    certState.Certificate,
+		myCert:    crt,
+
 		writeLock: newSyncMutex("connection-state-write"),
 	}
+	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
+	ci.messageCounter.Add(2)
 
-	return ci
+	return ci, nil
 }
 
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@@ -83,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"message_counter": cs.messageCounter.Load(),
 	})
 }
+
+func (cs *ConnectionState) Curve() cert.Curve {
+	return cs.myCert.Curve()
+}

+ 78 - 43
control.go

@@ -2,7 +2,7 @@ package nebula
 
 import (
 	"context"
-	"net"
+	"net/netip"
 	"os"
 	"os/signal"
 	"syscall"
@@ -10,9 +10,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
-	"github.com/slackhq/nebula/udp"
 )
 
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -21,10 +19,10 @@ import (
 type controlEach func(h *HostInfo)
 
 type controlHostLister interface {
-	QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+	QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
-	ForEachVpnIp(each controlEach)
-	GetPreferredRanges() []*net.IPNet
+	ForEachVpnAddr(each controlEach)
+	GetPreferredRanges() []netip.Prefix
 }
 
 type Control struct {
@@ -39,15 +37,15 @@ type Control struct {
 }
 
 type ControlHostInfo struct {
-	VpnIp                  net.IP                  `json:"vpnIp"`
-	LocalIndex             uint32                  `json:"localIndex"`
-	RemoteIndex            uint32                  `json:"remoteIndex"`
-	RemoteAddrs            []*udp.Addr             `json:"remoteAddrs"`
-	Cert                   *cert.NebulaCertificate `json:"cert"`
-	MessageCounter         uint64                  `json:"messageCounter"`
-	CurrentRemote          *udp.Addr               `json:"currentRemote"`
-	CurrentRelaysToMe      []iputil.VpnIp          `json:"currentRelaysToMe"`
-	CurrentRelaysThroughMe []iputil.VpnIp          `json:"currentRelaysThroughMe"`
+	VpnAddrs               []netip.Addr     `json:"vpnAddrs"`
+	LocalIndex             uint32           `json:"localIndex"`
+	RemoteIndex            uint32           `json:"remoteIndex"`
+	RemoteAddrs            []netip.AddrPort `json:"remoteAddrs"`
+	Cert                   cert.Certificate `json:"cert"`
+	MessageCounter         uint64           `json:"messageCounter"`
+	CurrentRemote          netip.AddrPort   `json:"currentRemote"`
+	CurrentRelaysToMe      []netip.Addr     `json:"currentRelaysToMe"`
+	CurrentRelaysThroughMe []netip.Addr     `json:"currentRelaysThroughMe"`
 }
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -131,8 +129,48 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 	}
 }
 
-// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
+// GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
+func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
+	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
+	if found {
+		// Only returning the default certificate since its impossible
+		// for any other host but ourselves to have more than 1
+		return c.f.pki.getCertState().GetDefaultCertificate().Copy()
+	}
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
+	if hi == nil {
+		return nil
+	}
+	return hi.GetCert().Certificate.Copy()
+}
+
+// CreateTunnel creates a new tunnel to the given vpn ip.
+func (c *Control) CreateTunnel(vpnIp netip.Addr) {
+	c.f.handshakeManager.StartHandshake(vpnIp, nil)
+}
+
+// PrintTunnel creates a new tunnel to the given vpn ip.
+func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
+	if hi == nil {
+		return nil
+	}
+	chi := copyHostInfo(hi, c.f.hostMap.GetPreferredRanges())
+	return &chi
+}
+
+// QueryLighthouse queries the lighthouse.
+func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
+	hi := c.f.lightHouse.Query(vpnIp)
+	if hi == nil {
+		return nil
+	}
+	return hi.CopyCache()
+}
+
+// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	if pending {
 		hl = c.f.handshakeManager
@@ -140,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 		hl = c.f.hostMap
 	}
 
-	h := hl.QueryVpnIp(vpnIp)
+	h := hl.QueryVpnAddr(vpnAddr)
 	if h == nil {
 		return nil
 	}
@@ -150,20 +188,22 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 }
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 		return nil
 	}
 
-	hostInfo.SetRemote(addr.Copy())
+	hostInfo.SetRemote(addr)
 	ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
 	return &ch
 }
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 		return false
 	}
@@ -187,29 +227,24 @@ func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
 // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
 // the int returned is a count of tunnels closed
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
-	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
-	lighthouses := c.f.lightHouse.GetLighthouses()
-
 	shutdown := func(h *HostInfo) {
-		if excludeLighthouses {
-			if _, ok := lighthouses[h.vpnIp]; ok {
-				return
-			}
+		if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
+			return
 		}
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.closeTunnel(h)
 
-		c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
+		c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
 			Debug("Sending close tunnel message")
 		closed++
 	}
 
 	// Learn which hosts are being used as relays, so we can shut them down last.
-	relayingHosts := map[iputil.VpnIp]*HostInfo{}
+	relayingHosts := map[netip.Addr]*HostInfo{}
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
-		relayingHosts[relayingHost.vpnIp] = relayingHost
+		relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
 	}
 	c.f.hostMap.Unlock()
 
@@ -217,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Hosts map
 	c.f.hostMap.Lock()
 	for _, relayHost := range c.f.hostMap.Indexes {
-		if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
+		if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
 			hostInfos = append(hostInfos, relayHost)
 		}
 	}
@@ -236,15 +271,19 @@ func (c *Control) Device() overlay.Device {
 	return c.f.inside
 }
 
-func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
-
+func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp.ToIP(),
+		VpnAddrs:               make([]netip.Addr, len(h.vpnAddrs)),
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
+		CurrentRemote:          h.remote,
+	}
+
+	for i, a := range h.vpnAddrs {
+		chi.VpnAddrs[i] = a
 	}
 
 	if h.ConnectionState != nil {
@@ -252,11 +291,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	}
 
 	if c := h.GetCert(); c != nil {
-		chi.Cert = c.Copy()
-	}
-
-	if h.remote != nil {
-		chi.CurrentRemote = h.remote.Copy()
+		chi.Cert = c.Certificate.Copy()
 	}
 
 	return chi
@@ -265,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 	hosts := make([]ControlHostInfo, 0)
 	pr := hl.GetPreferredRanges()
-	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+	hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 	})
 	return hosts

+ 40 - 45
control_test.go

@@ -2,72 +2,66 @@ package nebula
 
 import (
 	"net"
+	"net/netip"
 	"reflect"
 	"testing"
-	"time"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
-	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestControl_GetHostInfoByVpnIp(t *testing.T) {
+	//TODO: CERT-V2 with multiple certificate versions we have a problem with this test
+	// Some certs versions have different characteristics and each version implements their own Copy() func
+	// which means this is not a good place to test for exposing memory
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, &net.IPNet{})
-	hm.preferredRanges.Store(&[]*net.IPNet{})
+	hm := newHostMap(l)
+	hm.preferredRanges.Store(&[]netip.Prefix{})
+
+	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
+	remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
 
-	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
-	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
+		IP:   remote1.Addr().AsSlice(),
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 
 	ipNet2 := net.IPNet{
-		IP:   net.ParseIP("1:2:3:4:5:6:7:8"),
+		IP:   remote2.Addr().AsSlice(),
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 
-	crt := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test",
-			Ips:            []*net.IPNet{&ipNet},
-			Subnets:        []*net.IPNet{},
-			Groups:         []string{"default-group"},
-			NotBefore:      time.Unix(1, 0),
-			NotAfter:       time.Unix(2, 0),
-			PublicKey:      []byte{5, 6, 7, 8},
-			IsCA:           false,
-			Issuer:         "the-issuer",
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-		},
-		Signature: []byte{1, 2, 1, 2, 1, 3},
-	}
+	remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
 
-	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
-	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
+	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
+	assert.True(t, ok)
+
+	crt := &dummyCert{}
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
-			peerCert: crt,
+			peerCert: &cert.CachedCertificate{Certificate: crt},
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         iputil.Ip2VpnIp(ipNet.IP),
+		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
+	vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
+	assert.True(t, ok)
+
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
@@ -76,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         iputil.Ip2VpnIp(ipNet2.IP),
+		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
@@ -91,31 +85,32 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 	}
 
-	thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
+	thi := c.GetHostInfoByVpnAddr(vpnIp, false)
 
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  net.IPv4(1, 2, 3, 4).To4(),
+		VpnAddrs:               []netip.Addr{vpnIp},
 		LocalIndex:             201,
 		RemoteIndex:            200,
-		RemoteAddrs:            []*udp.Addr{remote2, remote1},
+		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
 		Cert:                   crt.Copy(),
 		MessageCounter:         0,
-		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
-		CurrentRelaysToMe:      []iputil.VpnIp{},
-		CurrentRelaysThroughMe: []iputil.VpnIp{},
+		CurrentRemote:          remote1,
+		CurrentRelaysToMe:      []netip.Addr{},
+		CurrentRelaysThroughMe: []netip.Addr{},
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assert.Equal(t, &expectedInfo, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
+		thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
 	})
 }
 
-func assertFields(t *testing.T, expected []string, actualStruct interface{}) {
+func assertFields(t *testing.T, expected []string, actualStruct any) {
 	val := reflect.ValueOf(actualStruct).Elem()
 	fields := make([]string, val.NumField())
 	for i := 0; i < val.NumField(); i++ {

+ 51 - 37
control_tester.go

@@ -4,14 +4,11 @@
 package nebula
 
 import (
-	"net"
-
-	"github.com/slackhq/nebula/cert"
+	"net/netip"
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
@@ -50,37 +47,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 
 // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
 // This is necessary if you did not configure static hosts or are not running a lighthouse
-func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
+func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	iVpnIp := iputil.Ip2VpnIp(vpnIp)
-	if v4 := toAddr.IP.To4(); v4 != nil {
-		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+	if toAddr.Addr().Is4() {
+		remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
 	} else {
-		remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+		remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
 	}
 }
 
 // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
 // This is necessary to inform an initiator of possible relays for communicating with a responder
-func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) {
+func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	iVpnIp := iputil.Ip2VpnIp(vpnIp)
-	uVpnIp := []uint32{}
-	for _, rVPnIp := range relayVpnIps {
-		uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
-	}
-
-	remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
+	remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
 }
 
 // GetFromTun will pull a packet off the tun side of nebula
@@ -107,20 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
-	ip := layers.IPv4{
-		Version:  4,
-		TTL:      64,
-		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().IP,
-		DstIP:    toIp,
+func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
+	serialize := make([]gopacket.SerializableLayer, 0)
+	var netLayer gopacket.NetworkLayer
+	if toAddr.Is6() {
+		if !fromAddr.Is6() {
+			panic("Cant send ipv6 to ipv4")
+		}
+		ip := &layers.IPv6{
+			Version:    6,
+			NextHeader: layers.IPProtocolUDP,
+			SrcIP:      fromAddr.Unmap().AsSlice(),
+			DstIP:      toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
+	} else {
+		if !fromAddr.Is4() {
+			panic("Cant send ipv4 to ipv6")
+		}
+
+		ip := &layers.IPv4{
+			Version:  4,
+			TTL:      64,
+			Protocol: layers.IPProtocolUDP,
+			SrcIP:    fromAddr.Unmap().AsSlice(),
+			DstIP:    toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
 	}
 
 	udp := layers.UDP{
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 	}
-	err := udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(netLayer)
 	if err != nil {
 		panic(err)
 	}
@@ -130,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 		ComputeChecksums: true,
 		FixLengths:       true,
 	}
-	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+
+	serialize = append(serialize, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, serialize...)
 	if err != nil {
 		panic(err)
 	}
@@ -138,16 +152,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 
-func (c *Control) GetVpnIp() iputil.VpnIp {
-	return c.f.myVpnIp
+func (c *Control) GetVpnAddrs() []netip.Addr {
+	return c.f.myVpnAddrs
 }
 
-func (c *Control) GetUDPAddr() string {
-	return c.f.outside.(*udp.TesterConn).Addr.String()
+func (c *Control) GetUDPAddr() netip.AddrPort {
+	return c.f.outside.(*udp.TesterConn).Addr
 }
 
-func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
+	hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
 	if hostinfo == nil {
 		return false
 	}
@@ -160,10 +174,10 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }
 
-func (c *Control) GetCert() *cert.NebulaCertificate {
-	return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertState() *CertState {
+	return c.f.pki.getCertState()
 }
 
-func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+func (c *Control) ReHandshake(vpnIp netip.Addr) {
 	c.f.handshakeManager.StartHandshake(vpnIp, nil)
 }

+ 86 - 40
dns_server.go

@@ -3,13 +3,14 @@ package nebula
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"strconv"
 	"strings"
 
+	"github.com/gaissmai/bart"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 // This whole thing should be rewritten to use context
@@ -20,74 +21,121 @@ var dnsAddr string
 
 type dnsRecords struct {
 	syncRWMutex
-	dnsMap  map[string]string
-	hostMap *HostMap
+	l               *logrus.Logger
+	dnsMap4         map[string]netip.Addr
+	dnsMap6         map[string]netip.Addr
+	hostMap         *HostMap
+	myVpnAddrsTable *bart.Table[struct{}]
 }
 
-func newDnsRecords(hostMap *HostMap) *dnsRecords {
+func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
 	return &dnsRecords{
-		syncRWMutex: newSyncRWMutex("dns-records"),
-		dnsMap:      make(map[string]string),
-		hostMap:     hostMap,
+		syncRWMutex:     newSyncRWMutex("dns-records"),
+		l:               l,
+		dnsMap4:         make(map[string]netip.Addr),
+		dnsMap6:         make(map[string]netip.Addr),
+		hostMap:         hostMap,
+		myVpnAddrsTable: cs.myVpnAddrsTable,
 	}
 }
 
-func (d *dnsRecords) Query(data string) string {
+func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
+	data = strings.ToLower(data)
 	d.RLock()
 	defer d.RUnlock()
-	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
-		return r
+	switch q {
+	case dns.TypeA:
+		if r, ok := d.dnsMap4[data]; ok {
+			return r
+		}
+	case dns.TypeAAAA:
+		if r, ok := d.dnsMap6[data]; ok {
+			return r
+		}
 	}
-	return ""
+
+	return netip.Addr{}
 }
 
 func (d *dnsRecords) QueryCert(data string) string {
-	ip := net.ParseIP(data[:len(data)-1])
-	if ip == nil {
+	ip, err := netip.ParseAddr(data[:len(data)-1])
+	if err != nil {
 		return ""
 	}
-	iip := iputil.Ip2VpnIp(ip)
-	hostinfo := d.hostMap.QueryVpnIp(iip)
+
+	hostinfo := d.hostMap.QueryVpnAddr(ip)
 	if hostinfo == nil {
 		return ""
 	}
+
 	q := hostinfo.GetCert()
 	if q == nil {
 		return ""
 	}
-	cert := q.Details
-	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
-	return c
+
+	b, err := q.Certificate.MarshalJSON()
+	if err != nil {
+		return ""
+	}
+	return string(b)
 }
 
-func (d *dnsRecords) Add(host, data string) {
+// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
+func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
+	host = strings.ToLower(host)
 	d.Lock()
 	defer d.Unlock()
-	d.dnsMap[strings.ToLower(host)] = data
+	haveV4 := false
+	haveV6 := false
+	for _, addr := range addresses {
+		if addr.Is4() && !haveV4 {
+			d.dnsMap4[host] = addr
+			haveV4 = true
+		} else if addr.Is6() && !haveV6 {
+			d.dnsMap6[host] = addr
+			haveV6 = true
+		}
+		if haveV4 && haveV6 {
+			break
+		}
+	}
 }
 
-func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
+func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
+	a, _, _ := net.SplitHostPort(addr)
+	b, err := netip.ParseAddr(a)
+	if err != nil {
+		return false
+	}
+
+	if b.IsLoopback() {
+		return true
+	}
+
+	_, found := d.myVpnAddrsTable.Lookup(b)
+	return found //if we found it in this table, it's good
+}
+
+func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 	for _, q := range m.Question {
 		switch q.Qtype {
-		case dns.TypeA:
-			l.Debugf("Query for A %s", q.Name)
-			ip := dnsR.Query(q.Name)
-			if ip != "" {
-				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
+		case dns.TypeA, dns.TypeAAAA:
+			qType := dns.TypeToString[q.Qtype]
+			d.l.Debugf("Query for %s %s", qType, q.Name)
+			ip := d.Query(q.Qtype, q.Name)
+			if ip.IsValid() {
+				rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
 				if err == nil {
 					m.Answer = append(m.Answer, rr)
 				}
 			}
 		case dns.TypeTXT:
-			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b := net.ParseIP(a)
-			// We don't answer these queries from non nebula nodes or localhost
-			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
-			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
+			// We only answer these queries from nebula nodes or localhost
+			if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
 				return
 			}
-			l.Debugf("Query for TXT %s", q.Name)
-			ip := dnsR.QueryCert(q.Name)
+			d.l.Debugf("Query for TXT %s", q.Name)
+			ip := d.QueryCert(q.Name)
 			if ip != "" {
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				if err == nil {
@@ -102,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 	}
 }
 
-func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
+func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.Compress = false
 
 	switch r.Opcode {
 	case dns.OpcodeQuery:
-		parseQuery(l, m, w)
+		d.parseQuery(m, w)
 	}
 
 	w.WriteMsg(m)
 }
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
-	dnsR = newDnsRecords(hostMap)
+func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
+	dnsR = newDnsRecords(l, cs, hostMap)
 
 	// attach request handler func
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		handleDnsRequest(l, w, r)
-	})
+	dns.HandleFunc(".", dnsR.handleDnsRequest)
 
 	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)

+ 28 - 13
dns_server_test.go

@@ -1,46 +1,61 @@
 package nebula
 
 import (
+	"net/netip"
 	"testing"
 
 	"github.com/miekg/dns"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestParsequery(t *testing.T) {
-	//TODO: This test is basically pointless
+	l := logrus.New()
 	hostMap := &HostMap{}
-	ds := newDnsRecords(hostMap)
-	ds.Add("test.com.com", "1.2.3.4")
+	ds := newDnsRecords(l, &CertState{}, hostMap)
+	addrs := []netip.Addr{
+		netip.MustParseAddr("1.2.3.4"),
+		netip.MustParseAddr("1.2.3.5"),
+		netip.MustParseAddr("fd01::24"),
+		netip.MustParseAddr("fd01::25"),
+	}
+	ds.Add("test.com.com", addrs)
 
-	m := new(dns.Msg)
+	m := &dns.Msg{}
 	m.SetQuestion("test.com.com", dns.TypeA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
 
-	//parseQuery(m)
+	m = &dns.Msg{}
+	m.SetQuestion("test.com.com", dns.TypeAAAA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
 }
 
 func Test_getDnsServerAddr(t *testing.T) {
 	c := config.NewC(nil)
 
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "0.0.0.0",
 			"port": "1",
 		},
 	}
 	assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
 
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "::",
 			"port": "1",
 		},
 	}
 	assert.Equal(t, "[::]:1", getDnsServerAddr(c))
 
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "[::]",
 			"port": "1",
 		},
@@ -48,8 +63,8 @@ func Test_getDnsServerAddr(t *testing.T) {
 	assert.Equal(t, "[::]:1", getDnsServerAddr(c))
 
 	// Make sure whitespace doesn't mess us up
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"dns": map[interface{}]interface{}{
+	c.Settings["lighthouse"] = map[string]any{
+		"dns": map[string]any{
 			"host": "[::] ",
 			"port": "1",
 		},

File diff suppressed because it is too large
+ 394 - 173
e2e/handshakes_test.go


+ 0 - 118
e2e/helpers.go

@@ -1,118 +0,0 @@
-package e2e
-
-import (
-	"crypto/rand"
-	"io"
-	"net"
-	"time"
-
-	"github.com/slackhq/nebula/cert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-)
-
-// NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(cert.Curve_CURVE25519, priv)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, priv, pem
-}
-
-// NewTestCert will generate a signed certificate with the provided details.
-// Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		panic(err)
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	pub, rawPriv := x25519Keypair()
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           name,
-			Ips:            []*net.IPNet{ip},
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}

+ 92 - 51
e2e/helpers_test.go

@@ -6,8 +6,9 @@ package e2e
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
@@ -17,29 +18,47 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
+func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 
-	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
-	copy(vpnIpNet.IP, udpIp)
-	vpnIpNet.IP[1] += 128
-	udpAddr := net.UDPAddr{
-		IP:   udpIp,
-		Port: 4242,
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
+	}
+
+	var udpAddr netip.AddrPort
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
+		budpIp[1] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+	} else {
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
-	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
-	caB, err := caCrt.MarshalToPEM()
+	caB, err := caCrt.MarshalPEM()
 	if err != nil {
 		panic(err)
 	}
@@ -67,8 +86,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		//	"try_interval": "1s",
 		//},
 		"listen": m{
-			"host": udpAddr.IP.String(),
-			"port": udpAddr.Port,
+			"host": udpAddr.Addr().String(),
+			"port": udpAddr.Port(),
 		},
 		"logging": m{
 			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
@@ -81,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 	}
 
 	if overrides != nil {
-		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		final := m{}
+		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
 		if err != nil {
 			panic(err)
 		}
-		mc = overrides
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = final
 	}
 
 	cb, err := yaml.Marshal(mc)
@@ -102,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 	}
 
-	return control, vpnIpNet, &udpAddr, c
+	return control, vpnNetworks, udpAddr, c
 }
 
 type doneCb func()
@@ -123,64 +147,73 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 	}
 }
 
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
-	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
+	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 	// And once more from me to them
-	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
+	controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 
-func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
+	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
+	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
-	hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
+	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
+	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
+	assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
 
-	assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
-	assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
-
-	assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
-	assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
+	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
+	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
 
 	// Check that our indexes match
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
+}
 
-	//TODO: Would be nice to assert this memory
-	//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
-	//	hBbyIndex := hmA.Indexes[hBinA.localIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
-	//	assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
-	//
-	//	//TODO: remote indexes are susceptible to collision
-	//	hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
-	//	assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
-	//}
-	//
-	//// Check hostmap indexes too
-	//checkIndexes("hmA", hmA, hBinA)
-	//checkIndexes("hmB", hmB, hAinB)
+func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	if toIp.Is6() {
+		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
+	} else {
+		assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
+	}
 }
 
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
+func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
+	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+	assert.NotNil(t, v6, "No ipv6 data found")
+
+	assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
+
+	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	assert.NotNil(t, udp, "No udp data found")
+
+	assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
+	assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
+
+	data := packet.ApplicationLayer()
+	assert.NotNil(t, data)
+	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
+}
+
+func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
 
-	assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
-	assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
+	assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
 
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
 	assert.NotNil(t, udp, "No udp data found")
@@ -193,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, from
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 
+func getAddrs(ns []netip.Prefix) []netip.Addr {
+	var a []netip.Addr
+	for _, n := range ns {
+		a = append(a, n.Addr())
+	}
+	return a
+}
+
 func NewTestLogger() *logrus.Logger {
 	l := logrus.New()
 

+ 9 - 8
e2e/router/hostmap.go

@@ -5,11 +5,11 @@ package router
 
 import (
 	"fmt"
+	"net/netip"
 	"sort"
 	"strings"
 
 	"github.com/slackhq/nebula"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type edge struct {
@@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	var lines []string
 	var globalLines []*edge
 
-	clusterName := strings.Trim(c.GetCert().Details.Name, " ")
-	clusterVpnIp := c.GetCert().Details.Ips[0].IP
+	crt := c.GetCertState().GetDefaultCertificate()
+	clusterName := strings.Trim(crt.Name(), " ")
+	clusterVpnIp := crt.Networks()[0].Addr()
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 
 	hm := c.GetHostmap()
@@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	for _, idx := range indexes {
 		hi, ok := hm.Indexes[idx]
 		if ok {
-			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
-			remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
+			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
+			remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			_ = hi
 		}
@@ -118,14 +119,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	return r, globalLines
 }
 
-func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
-	keys := make([]iputil.VpnIp, 0, len(hosts))
+func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
+	keys := make([]netip.Addr, 0, len(hosts))
 	for key := range hosts {
 		keys = append(keys, key)
 	}
 
 	sort.SliceStable(keys, func(i, j int) bool {
-		return keys[i] > keys[j]
+		return keys[i].Compare(keys[j]) > 0
 	})
 
 	return keys

+ 72 - 77
e2e/router/router.go

@@ -6,13 +6,12 @@ package router
 import (
 	"context"
 	"fmt"
-	"net"
+	"net/netip"
 	"os"
 	"path/filepath"
 	"reflect"
+	"regexp"
 	"sort"
-	"strconv"
-	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -21,7 +20,6 @@ import (
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"golang.org/x/exp/maps"
 )
@@ -29,18 +27,18 @@ import (
 type R struct {
 	// Simple map of the ip:port registered on a control to the control
 	// Basically a router, right?
-	controls map[string]*nebula.Control
+	controls map[netip.AddrPort]*nebula.Control
 
 	// A map for inbound packets for a control that doesn't know about this address
-	inNat map[string]*nebula.Control
+	inNat map[netip.AddrPort]*nebula.Control
 
 	// A last used map, if an inbound packet hit the inNat map then
 	// all return packets should use the same last used inbound address for the outbound sender
 	// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
-	outNat map[string]net.UDPAddr
+	outNat map[string]netip.AddrPort
 
 	// A map of vpn ip to the nebula control it belongs to
-	vpnControls map[iputil.VpnIp]*nebula.Control
+	vpnControls map[netip.Addr]*nebula.Control
 
 	ignoreFlows []ignoreFlow
 	flow        []flowEntry
@@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	}
 
 	r := &R{
-		controls:     make(map[string]*nebula.Control),
-		vpnControls:  make(map[iputil.VpnIp]*nebula.Control),
-		inNat:        make(map[string]*nebula.Control),
-		outNat:       make(map[string]net.UDPAddr),
+		controls:     make(map[netip.AddrPort]*nebula.Control),
+		vpnControls:  make(map[netip.Addr]*nebula.Control),
+		inNat:        make(map[netip.AddrPort]*nebula.Control),
+		outNat:       make(map[string]netip.AddrPort),
 		flow:         []flowEntry{},
 		ignoreFlows:  []ignoreFlow{},
 		fn:           filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -135,10 +133,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	for _, c := range controls {
 		addr := c.GetUDPAddr()
 		if _, ok := r.controls[addr]; ok {
-			panic("Duplicate listen address: " + addr)
+			panic("Duplicate listen address: " + addr.String())
+		}
+
+		for _, vpnAddr := range c.GetVpnAddrs() {
+			r.vpnControls[vpnAddr] = c
 		}
 
-		r.vpnControls[c.GetVpnIp()] = c
 		r.controls[addr] = c
 	}
 
@@ -165,13 +166,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 // It does not look at the addr attached to the instance.
 // If a route is used, this will behave like a NAT for the return path.
 // Rewriting the source ip:port to what was last sent to from the origin
-func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
+func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
 	r.Lock()
 	defer r.Unlock()
 
-	inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
+	inAddr := netip.AddrPortFrom(ip, port)
 	if _, ok := r.inNat[inAddr]; ok {
-		panic("Duplicate listen address inNat: " + inAddr)
+		panic("Duplicate listen address inNat: " + inAddr.String())
 	}
 	r.inNat[inAddr] = c
 }
@@ -198,7 +199,7 @@ func (r *R) renderFlow() {
 		panic(err)
 	}
 
-	var participants = map[string]struct{}{}
+	var participants = map[netip.AddrPort]struct{}{}
 	var participantsVals []string
 
 	fmt.Fprintln(f, "```mermaid")
@@ -215,11 +216,11 @@ func (r *R) renderFlow() {
 			continue
 		}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr, ":", "-", 1)
+		sanAddr := normalizeName(addr.String())
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
-			sanAddr, e.packet.from.GetVpnIp(), sanAddr,
+			sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
 		)
 	}
 
@@ -252,9 +253,9 @@ func (r *R) renderFlow() {
 
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
+				normalizeName(p.from.GetUDPAddr().String()),
 				line,
-				strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+				normalizeName(p.to.GetUDPAddr().String()),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 		}
@@ -269,6 +270,11 @@ func (r *R) renderFlow() {
 	}
 }
 
+func normalizeName(s string) string {
+	rx := regexp.MustCompile("[\\[\\]\\:]")
+	return rx.ReplaceAllLiteralString(s, "_")
+}
+
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@@ -305,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp() > c[j].GetVpnIp()
+		return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
 	})
 
 	s := renderHostmaps(c...)
@@ -420,13 +426,12 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 
 		// Nope, lets push the sender along
 		case p := <-udpTx:
-			outAddr := sender.GetUDPAddr()
 			r.Lock()
-			inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-			c := r.getControl(outAddr, inAddr, p)
+			a := sender.GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 				r.Unlock()
-				panic("No control for udp tx")
+				panic("No control for udp tx " + a.String())
 			}
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			c.InjectUDPPacket(p)
@@ -479,13 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
-			outAddr := cm[x].GetUDPAddr()
-
-			inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-			c := r.getControl(outAddr, inAddr, p)
+			a := cm[x].GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 				r.Unlock()
-				panic("No control for udp tx")
+				panic(fmt.Sprintf("No control for udp tx %s", p.To))
 			}
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			c.InjectUDPPacket(p)
@@ -509,12 +512,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 			panic(err)
 		}
 
-		outAddr := sender.GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
 		if receiver == nil {
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't RouteExitFunc for host: " + p.To.String())
 		}
 
 		e := whatDo(p, receiver)
@@ -590,13 +591,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
 // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
 // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
 	if finish == KeepRouting {
 		finish = RouteAndExit
 	}
 
 	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
-		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
+		if p.To == toAddr {
 			return finish
 		}
 
@@ -630,13 +631,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
 		r.Lock()
 
 		p := rx.Interface().(*udp.Packet)
-
-		outAddr := cm[x].GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 		if receiver == nil {
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't RouteForAllExitFunc for host: " + p.To.String())
 		}
 
 		e := whatDo(p, receiver)
@@ -697,12 +695,10 @@ func (r *R) FlushAll() {
 
 		p := rx.Interface().(*udp.Packet)
 
-		outAddr := cm[x].GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 		if receiver == nil {
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't FlushAll for host: " + p.To.String())
 		}
 		r.Unlock()
 	}
@@ -710,28 +706,14 @@ func (r *R) FlushAll() {
 
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
-	if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
-		p.FromIp = newAddr.IP
-		p.FromPort = uint16(newAddr.Port)
+func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
+	if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
+		p.From = newAddr
 	}
 
 	c, ok := r.inNat[toAddr]
 	if ok {
-		sHost, sPort, err := net.SplitHostPort(toAddr)
-		if err != nil {
-			panic(err)
-		}
-
-		port, err := strconv.Atoi(sPort)
-		if err != nil {
-			panic(err)
-		}
-
-		r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
-			IP:   net.ParseIP(sHost),
-			Port: port,
-		}
+		r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
 		return c
 	}
 
@@ -739,29 +721,42 @@ func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
 }
 
 func (r *R) formatUdpPacket(p *packet) string {
-	packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
-	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
-	if v4 == nil {
-		panic("not an ipv4 packet")
+	var packet gopacket.Packet
+	var srcAddr netip.Addr
+
+	packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
+	if packet.ErrorLayer() == nil {
+		v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
+	} else {
+		packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
+		v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
 	}
 
 	from := "unknown"
-	if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok {
-		from = c.GetUDPAddr()
+	if c, ok := r.vpnControls[srcAddr]; ok {
+		from = c.GetUDPAddr().String()
 	}
 
-	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
-	if udp == nil {
+	udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	if udpLayer == nil {
 		panic("not a udp packet")
 	}
 
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
-		udp.SrcPort,
-		udp.DstPort,
+		normalizeName(from),
+		normalizeName(p.to.GetUDPAddr().String()),
+		udpLayer.SrcPort,
+		udpLayer.DstPort,
 		string(data.Payload()),
 	)
 }

+ 44 - 13
examples/config.yml

@@ -13,6 +13,12 @@ pki:
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   #disconnect_invalid: true
 
+  # default_version controls which certificate version is used in handshakes.
+  # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
+  # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
+  # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
+  # default_version: 1
+
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # The syntax is:
@@ -120,8 +126,8 @@ lighthouse:
 # Port Nebula will be listening on. The default here is 4242. For a lighthouse node, the port should be defined,
 # however using port 0 will dynamically assign a port and is recommended for roaming nodes.
 listen:
-  # To listen on both any ipv4 and ipv6 use "::"
-  host: 0.0.0.0
+  # To listen on only ipv4, use "0.0.0.0"
+  host: "::"
   port: 4242
   # Sets the max number of packets to pull from the kernel for each syscall (under systems that support recvmmsg)
   # default is 64, does not support reload
@@ -138,6 +144,11 @@ listen:
   # valid values: always, never, private
   # This setting is reloadable.
   #send_recv_error: always
+  # The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier.
+  # This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes,
+  # allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set.
+  # This setting is reloadable.
+  #so_mark: 0
 
 # Routines is the number of thread pairs to run that consume from the tun and UDP queues.
 # Currently, this defaults to 1 which means we have 1 tun queue reader and 1
@@ -228,7 +239,28 @@ tun:
 
   # Unsafe routes allows you to route traffic over nebula to non-nebula nodes
   # Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
-  # NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
+  # Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
+  # NOTES:
+  # * You will only see a single gateway in the routing table if you are not on linux
+  # * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
+  #
+  # unsafe_routes:
+  # # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
+  # - route: 192.168.87.0/24
+  #   via:
+  #     - gateway: 10.0.0.1
+  #     - gateway: 10.0.0.2
+  #     - gateway: 10.0.0.3
+  # # Multiple gateways with a weight, this will balance traffic accordingly
+  # - route: 192.168.87.0/24
+  #   via:
+  #     - gateway: 10.0.0.1
+  #       weight: 10
+  #     - gateway: 10.0.0.2
+  #       weight: 5
+  #
+  # NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
+  # `via`: single node or list of gateways to use for this route
   # `mtu`: will default to tun mtu if this option is not specified
   # `metric`: will default to 0 if this option is not specified
   # `install`: will default to true, controls whether this route is installed in the systems routing table.
@@ -244,7 +276,6 @@ tun:
   # in nebula configuration files. Default false, not reloadable.
   #use_system_route_table: false
 
-# TODO
 # Configure logging level
 logging:
   # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
@@ -315,11 +346,11 @@ firewall:
   outbound_action: drop
   inbound_action: drop
 
-  # Controls the default value for local_cidr. Default is true, will be deprecated after v1.9 and defaulted to false.
-  # This setting only affects nebula hosts with subnets encoded in their certificate. A nebula host acting as an
-  # unsafe router with `default_local_cidr_any: true` will expose their unsafe routes to every inbound rule regardless
-  # of the actual destination for the packet. Setting this to false requires each inbound rule to contain a `local_cidr`
-  # if the intention is to allow traffic to flow to an unsafe route.
+  # THIS FLAG IS DEPRECATED AND WILL BE REMOVED IN A FUTURE RELEASE. (Defaults to false.)
+  # This setting only affects nebula hosts exposing unsafe_routes. When set to false, each inbound rule must contain a
+  # `local_cidr` if the intention is to allow traffic to flow to an unsafe route. When set to true, every firewall rule
+  # will apply to all configured unsafe_routes regardless of the actual destination of the packet, unless `local_cidr`
+  # is explicitly defined. This is usually not the desired behavior and should be avoided!
   #default_local_cidr_any: false
 
   conntrack:
@@ -336,10 +367,10 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
-  #      Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
-  #      if `default_local_cidr_any` is false, otherwise its `any`.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
+  #     By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
+  #     If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
 

+ 18 - 9
examples/go_service/main.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"fmt"
 	"log"
+	"net"
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/service"
@@ -54,16 +55,16 @@ pki:
   cert: /home/rice/Developer/nebula-config/app.crt
   key: /home/rice/Developer/nebula-config/app.key
 `
-	var config config.C
-	if err := config.LoadString(configStr); err != nil {
+	var cfg config.C
+	if err := cfg.LoadString(configStr); err != nil {
 		return err
 	}
-	service, err := service.New(&config)
+	svc, err := service.New(&cfg)
 	if err != nil {
 		return err
 	}
 
-	ln, err := service.Listen("tcp", ":1234")
+	ln, err := svc.Listen("tcp", ":1234")
 	if err != nil {
 		return err
 	}
@@ -73,16 +74,24 @@ pki:
 			log.Printf("accept error: %s", err)
 			break
 		}
-		defer conn.Close()
+		defer func(conn net.Conn) {
+			_ = conn.Close()
+		}(conn)
 
 		log.Printf("got connection")
 
-		conn.Write([]byte("hello world\n"))
+		_, err = conn.Write([]byte("hello world\n"))
+		if err != nil {
+			log.Printf("write error: %s", err)
+		}
 
 		scanner := bufio.NewScanner(conn)
 		for scanner.Scan() {
 			message := scanner.Text()
-			fmt.Fprintf(conn, "echo: %q\n", message)
+			_, err = fmt.Fprintf(conn, "echo: %q\n", message)
+			if err != nil {
+				log.Printf("write error: %s", err)
+			}
 			log.Printf("got message %q", message)
 		}
 
@@ -92,8 +101,8 @@ pki:
 		}
 	}
 
-	service.Close()
-	if err := service.Wait(); err != nil {
+	_ = svc.Close()
+	if err := svc.Wait(); err != nil {
 		return err
 	}
 	return nil

+ 128 - 110
firewall.go

@@ -6,22 +6,22 @@ import (
 	"errors"
 	"fmt"
 	"hash/fnv"
-	"net"
+	"net/netip"
 	"reflect"
 	"strconv"
 	"strings"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 )
 
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
 }
 
 type conn struct {
@@ -50,10 +50,13 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 
-	// Used to ensure we don't emit local packets for ips we don't own
-	localIps     *cidr.Tree4[struct{}]
-	assignedCIDR *net.IPNet
-	hasSubnets   bool
+	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
+	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
+	routableNetworks *bart.Table[struct{}]
+
+	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
+	assignedNetworks  []netip.Prefix
+	hasUnsafeNetworks bool
 
 	rules        string
 	rulesVersion uint16
@@ -66,9 +69,9 @@ type Firewall struct {
 }
 
 type firewallMetrics struct {
-	droppedLocalIP  metrics.Counter
-	droppedRemoteIP metrics.Counter
-	droppedNoRule   metrics.Counter
+	droppedLocalAddr  metrics.Counter
+	droppedRemoteAddr metrics.Counter
+	droppedNoRule     metrics.Counter
 }
 
 type FirewallConntrack struct {
@@ -107,7 +110,7 @@ type FirewallRule struct {
 	Any    *firewallLocalCIDR
 	Hosts  map[string]*firewallLocalCIDR
 	Groups []*firewallGroups
-	CIDR   *cidr.Tree4[*firewallLocalCIDR]
+	CIDR   *bart.Table[*firewallLocalCIDR]
 }
 
 type firewallGroups struct {
@@ -121,85 +124,92 @@ type firewallPort map[int32]*FirewallCA
 
 type firewallLocalCIDR struct {
 	Any       bool
-	LocalCIDR *cidr.Tree4[struct{}]
+	LocalCIDR *bart.Table[struct{}]
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
-func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
+// The certificate provided should be the highest version loaded in memory.
+func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
 	//TODO: error on 0 duration
-	var min, max time.Duration
+	var tmin, tmax time.Duration
 
 	if tcpTimeout < UDPTimeout {
-		min = tcpTimeout
-		max = UDPTimeout
+		tmin = tcpTimeout
+		tmax = UDPTimeout
 	} else {
-		min = UDPTimeout
-		max = tcpTimeout
+		tmin = UDPTimeout
+		tmax = tcpTimeout
 	}
 
-	if defaultTimeout < min {
-		min = defaultTimeout
-	} else if defaultTimeout > max {
-		max = defaultTimeout
+	if defaultTimeout < tmin {
+		tmin = defaultTimeout
+	} else if defaultTimeout > tmax {
+		tmax = defaultTimeout
 	}
 
-	localIps := cidr.NewTree4[struct{}]()
-	var assignedCIDR *net.IPNet
-	for _, ip := range c.Details.Ips {
-		ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
-		localIps.AddCIDR(ipNet, struct{}{})
-
-		if assignedCIDR == nil {
-			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = ipNet
-		}
+	routableNetworks := new(bart.Table[struct{}])
+	var assignedNetworks []netip.Prefix
+	for _, network := range c.Networks() {
+		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
+		routableNetworks.Insert(nprefix, struct{}{})
+		assignedNetworks = append(assignedNetworks, network)
 	}
 
-	for _, n := range c.Details.Subnets {
-		localIps.AddCIDR(n, struct{}{})
+	hasUnsafeNetworks := false
+	for _, n := range c.UnsafeNetworks() {
+		routableNetworks.Insert(n, struct{}{})
+		hasUnsafeNetworks = true
 	}
 
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 			syncMutex:  newSyncMutex("firewall-conntrack"),
 			Conns:      make(map[firewall.Packet]*conn),
-			TimerWheel: NewTimerWheel[firewall.Packet](min, max),
+			TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
 		},
-		InRules:        newFirewallTable(),
-		OutRules:       newFirewallTable(),
-		TCPTimeout:     tcpTimeout,
-		UDPTimeout:     UDPTimeout,
-		DefaultTimeout: defaultTimeout,
-		localIps:       localIps,
-		assignedCIDR:   assignedCIDR,
-		hasSubnets:     len(c.Details.Subnets) > 0,
-		l:              l,
+		InRules:           newFirewallTable(),
+		OutRules:          newFirewallTable(),
+		TCPTimeout:        tcpTimeout,
+		UDPTimeout:        UDPTimeout,
+		DefaultTimeout:    defaultTimeout,
+		routableNetworks:  routableNetworks,
+		assignedNetworks:  assignedNetworks,
+		hasUnsafeNetworks: hasUnsafeNetworks,
+		l:                 l,
 
 		incomingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
 		},
 		outgoingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
 		},
 	}
 }
 
-func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
+	certificate := cs.getCertificate(cert.Version2)
+	if certificate == nil {
+		certificate = cs.getCertificate(cert.Version1)
+	}
+
+	if certificate == nil {
+		panic("No certificate available to reconfigure the firewall")
+	}
+
 	fw := NewFirewall(
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
-		nc,
+		certificate,
 		//TODO: max_connections
 	)
 
-	//TODO: Flip to false after v1.9 release
-	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
+	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
 
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	switch inboundAction {
@@ -237,15 +247,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
 }
 
 // AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// https://github.com/golang/go/issues/14131
 	sIp := ""
-	if ip != nil {
+	if ip.IsValid() {
 		sIp = ip.String()
 	}
 	lIp := ""
-	if localIp != nil {
+	if localIp.IsValid() {
 		lIp = localIp.String()
 	}
 
@@ -279,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		fp = ft.TCP
 	case firewall.ProtoUDP:
 		fp = ft.UDP
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		fp = ft.ICMP
 	case firewall.ProtoAny:
 		fp = ft.AnyProto
@@ -321,7 +331,7 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 		return nil
 	}
 
-	rs, ok := r.([]interface{})
+	rs, ok := r.([]any)
 	if !ok {
 		return fmt.Errorf("%s failed to parse, should be an array of rules", table)
 	}
@@ -382,17 +392,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 		}
 
-		var cidr *net.IPNet
+		var cidr netip.Prefix
 		if r.Cidr != "" {
-			_, cidr, err = net.ParseCIDR(r.Cidr)
+			cidr, err = netip.ParsePrefix(r.Cidr)
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
 			}
 		}
 
-		var localCidr *net.IPNet
+		var localCidr netip.Prefix
 		if r.LocalCidr != "" {
-			_, localCidr, err = net.ParseCIDR(r.LocalCidr)
+			localCidr, err = netip.ParsePrefix(r.LocalCidr)
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
 			}
@@ -413,31 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
+func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(fp, h, caPool, localCache) {
 		return nil
 	}
 
 	// Make sure remote address matches nebula certificate
-	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		ok, _ := remoteCidr.Contains(fp.RemoteIP)
+	if h.networks != nil {
+		_, ok := h.networks.Lookup(fp.RemoteAddr)
 		if !ok {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.vpnIp {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	ok, _ := f.localIps.Contains(fp.LocalIP)
+	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
 	if !ok {
-		f.metrics(incoming).droppedLocalIP.Inc(1)
+		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 	}
 
@@ -482,7 +492,7 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 
-func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
+func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
 	if localCache != nil {
 		if _, ok := localCache[fp]; ok {
 			return true
@@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Caller must own the connMutex lock!
 func (f *Firewall) evict(p firewall.Packet) {
-	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
 	conntrack := f.Conntrack
 	t, ok := conntrack.Conns[p]
@@ -610,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) {
 	delete(conntrack.Conns, p)
 }
 
-func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 		return true
 	}
@@ -624,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 		}
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 		}
@@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 	return false
 }
 
-func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 	}
@@ -654,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
 	return nil
 }
 
-func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	// We don't have any allowed ports, bail
 	if fp == nil {
 		return false
@@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 
-func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
 	fr := func() *FirewallRule {
 		return &FirewallRule{
 			Hosts:  make(map[string]*firewallLocalCIDR),
 			Groups: make([]*firewallGroups, 0),
-			CIDR:   cidr.NewTree4[*firewallLocalCIDR](),
+			CIDR:   new(bart.Table[*firewallLocalCIDR]),
 		}
 	}
 
@@ -717,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
 	return nil
 }
 
-func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	if fc == nil {
 		return false
 	}
@@ -726,24 +735,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 		return true
 	}
 
-	if t, ok := fc.CAShas[c.Details.Issuer]; ok {
+	if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok {
 		if t.match(p, c) {
 			return true
 		}
 	}
 
-	s, err := caPool.GetCAForCert(c)
+	s, err := caPool.GetCAForCert(c.Certificate)
 	if err != nil {
 		return false
 	}
 
-	return fc.CANames[s.Details.Name].match(p, c)
+	return fc.CANames[s.Certificate.Name()].match(p, c)
 }
 
-func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
 	flc := func() *firewallLocalCIDR {
 		return &firewallLocalCIDR{
-			LocalCIDR: cidr.NewTree4[struct{}](),
+			LocalCIDR: new(bart.Table[struct{}]),
 		}
 	}
 
@@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
 		fr.Hosts[host] = nlc
 	}
 
-	if ip != nil {
-		_, nlc := fr.CIDR.GetCIDR(ip)
+	if ip.IsValid() {
+		nlc, _ := fr.CIDR.Get(ip)
 		if nlc == nil {
 			nlc = flc()
 		}
@@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
 		if err != nil {
 			return err
 		}
-		fr.CIDR.AddCIDR(ip, nlc)
+		fr.CIDR.Insert(ip, nlc)
 	}
 
 	return nil
 }
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
-	if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
+	if len(groups) == 0 && host == "" && !ip.IsValid() {
 		return true
 	}
 
@@ -810,14 +819,14 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 		return true
 	}
 
-	if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
+	if ip.IsValid() && ip.Bits() == 0 {
 		return true
 	}
 
 	return false
 }
 
-func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool {
 	if fr == nil {
 		return false
 	}
@@ -832,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		found := false
 
 		for _, g := range sg.Groups {
-			if _, ok := c.Details.InvertedGroups[g]; !ok {
+			if _, ok := c.InvertedGroups[g]; !ok {
 				found = false
 				break
 			}
@@ -846,35 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 	}
 
 	if fr.Hosts != nil {
-		if flc, ok := fr.Hosts[c.Details.Name]; ok {
+		if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
 			if flc.match(p, c) {
 				return true
 			}
 		}
 	}
 
-	return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
-		return flc.match(p, c)
-	})
+	for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
+		if v.match(p, c) {
+			return true
+		}
+	}
+
+	return false
 }
 
-func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
-	if localIp == nil {
-		if !f.hasSubnets || f.defaultLocalCIDRAny {
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
+	if !localIp.IsValid() {
+		if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
 			flc.Any = true
 			return nil
 		}
 
-		localIp = f.assignedCIDR
-	} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
+		for _, network := range f.assignedNetworks {
+			flc.LocalCIDR.Insert(network, struct{}{})
+		}
+		return nil
+
+	} else if localIp.Bits() == 0 {
 		flc.Any = true
+		return nil
 	}
 
-	flc.LocalCIDR.AddCIDR(localIp, struct{}{})
+	flc.LocalCIDR.Insert(localIp, struct{}{})
 	return nil
 }
 
-func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
 	if flc == nil {
 		return false
 	}
@@ -883,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
 		return true
 	}
 
-	ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
 	return ok
 }
 
@@ -900,15 +918,15 @@ type rule struct {
 	CASha     string
 }
 
-func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
+func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
 	r := rule{}
 
-	m, ok := p.(map[interface{}]interface{})
+	m, ok := p.(map[string]any)
 	if !ok {
 		return r, errors.New("could not parse rule")
 	}
 
-	toString := func(k string, m map[interface{}]interface{}) string {
+	toString := func(k string, m map[string]any) string {
 		v, ok := m[k]
 		if !ok {
 			return ""
@@ -926,7 +944,7 @@ func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, er
 	r.CASha = toString("ca_sha", m)
 
 	// Make sure group isn't an array
-	if v, ok := m["group"].([]interface{}); ok {
+	if v, ok := m["group"].([]any); ok {
 		if len(v) > 1 {
 			return r, errors.New("group should contain a single value, an array with more than one entry was provided")
 		}

+ 13 - 13
firewall/packet.go

@@ -3,25 +3,25 @@ package firewall
 import (
 	"encoding/json"
 	"fmt"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 const (
-	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	ProtoTCP  = 6
-	ProtoUDP  = 17
-	ProtoICMP = 1
+	ProtoAny    = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP    = 6
+	ProtoUDP    = 17
+	ProtoICMP   = 1
+	ProtoICMPv6 = 58
 
 	PortAny      = 0  // Special value for matching `port: any`
 	PortFragment = -1 // Special value for matching `port: fragment`
 )
 
 type Packet struct {
-	LocalIP    iputil.VpnIp
-	RemoteIP   iputil.VpnIp
+	LocalAddr  netip.Addr
+	RemoteAddr netip.Addr
 	LocalPort  uint16
 	RemotePort uint16
 	Protocol   uint8
@@ -30,8 +30,8 @@ type Packet struct {
 
 func (fp *Packet) Copy() *Packet {
 	return &Packet{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
+		LocalAddr:  fp.LocalAddr,
+		RemoteAddr: fp.RemoteAddr,
 		LocalPort:  fp.LocalPort,
 		RemotePort: fp.RemotePort,
 		Protocol:   fp.Protocol,
@@ -52,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 	}
 	return json.Marshal(m{
-		"LocalIP":    fp.LocalIP.String(),
-		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalAddr":  fp.LocalAddr.String(),
+		"RemoteAddr": fp.RemoteAddr.String(),
 		"LocalPort":  fp.LocalPort,
 		"RemotePort": fp.RemotePort,
 		"Protocol":   proto,

File diff suppressed because it is too large
+ 259 - 317
firewall_test.go


+ 25 - 22
go.mod

@@ -1,55 +1,58 @@
 module github.com/slackhq/nebula
 
-go 1.22.0
+go 1.23.6
 
-toolchain go1.22.2
+toolchain go1.24.1
 
 require (
-	dario.cat/mergo v1.0.0
+	dario.cat/mergo v1.0.1
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
 	github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
+	github.com/gaissmai/bart v0.20.1
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.59
+	github.com/miekg/dns v1.1.64
+	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.19.0
+	github.com/prometheus/client_golang v1.21.1
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
-	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
-	github.com/stretchr/testify v1.9.0
+	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
+	github.com/stretchr/testify v1.10.0
 	github.com/timandy/routine v1.1.1
-	github.com/vishvananda/netlink v1.2.1-beta.2
-	golang.org/x/crypto v0.23.0
+	github.com/vishvananda/netlink v1.3.0
+	golang.org/x/crypto v0.36.0
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-	golang.org/x/net v0.25.0
-	golang.org/x/sync v0.7.0
-	golang.org/x/sys v0.20.0
-	golang.org/x/term v0.20.0
+	golang.org/x/net v0.38.0
+	golang.org/x/sync v0.12.0
+	golang.org/x/sys v0.31.0
+	golang.org/x/term v0.30.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.34.1
-	gopkg.in/yaml.v2 v2.4.0
+	google.golang.org/protobuf v1.36.6
+	gopkg.in/yaml.v3 v3.0.1
 	gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 )
 
 require (
 	github.com/beorn7/perks v1.0.1 // indirect
-	github.com/cespare/xxhash/v2 v2.2.0 // indirect
+	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
+	github.com/klauspost/compress v1.17.11 // indirect
+	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.5.0 // indirect
-	github.com/prometheus/common v0.48.0 // indirect
-	github.com/prometheus/procfs v0.12.0 // indirect
+	github.com/prometheus/client_model v0.6.1 // indirect
+	github.com/prometheus/common v0.62.0 // indirect
+	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
-	golang.org/x/mod v0.16.0 // indirect
+	golang.org/x/mod v0.23.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
-	golang.org/x/tools v0.19.0 // indirect
-	gopkg.in/yaml.v3 v3.0.1 // indirect
+	golang.org/x/tools v0.30.0 // indirect
 )

+ 48 - 41
go.sum

@@ -1,6 +1,6 @@
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
-dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
+dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
+dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -15,8 +15,8 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
-github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc h1:6e91sWiDE69Jl0WUsY/LvTCBPRBe6b2j8H7W96JGJ4s=
 github.com/clarkmcc/go-dag v0.0.0-20220908000337-9c3ba5b365fc/go.mod h1:RGIcF96ORCYAsdz60Ou9mPBNa4+DjoQFS8nelPniFoY=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
@@ -26,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/gaissmai/bart v0.20.1 h1:igNss0zDsSY8e+ophKgD9KJVPKBOo7uSVjyKCL7nIzo=
+github.com/gaissmai/bart v0.20.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -68,6 +70,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
 github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
+github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -78,13 +82,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
-github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
+github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ=
+github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg=
@@ -98,24 +108,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU=
-github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k=
+github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
+github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
-github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
+github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
+github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
-github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
+github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
+github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
-github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
+github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
+github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -127,21 +137,20 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/timandy/routine v1.1.1 h1:6/Z7qLFZj3GrzuRksBFzIG8YGUh8CLhjnnMePBQTrEI=
 github.com/timandy/routine v1.1.1/go.mod h1:OZHPOKSvqL/ZvqXFkNZyit0xIVelERptYXdAHH00adQ=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
-github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
-github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
-github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
+github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
+github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -151,16 +160,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
-golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
 golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
 golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
-golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
+golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
+golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -171,8 +180,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
-golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
+golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
+golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -180,30 +189,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
-golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
-golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
-golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
+golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -214,8 +223,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
 golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
-golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -234,8 +243,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
-google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
+google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
+google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -246,8 +255,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
-gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 303 - 117
handshake_ix.go

@@ -1,13 +1,14 @@
 package nebula
 
 import (
+	"net/netip"
+	"slices"
 	"time"
 
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 // NOISE IX Handshakes
@@ -17,40 +18,70 @@ import (
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 	}
 
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
-	hh.hostinfo.ConnectionState = ci
+	// If we're connecting to a v6 address we must use a v2 cert
+	cs := f.pki.getCertState()
+	v := cs.defaultVersion
+	for _, a := range hh.hostinfo.vpnAddrs {
+		if a.Is6() {
+			v = cert.Version2
+			break
+		}
+	}
 
-	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hh.hostinfo.localIndexId,
-		Time:           uint64(time.Now().UnixNano()),
-		Cert:           certState.RawCertificateNoKey,
+	crt := cs.getCertificate(v)
+	if crt == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate is available")
+		return false
 	}
 
-	hsBytes := []byte{}
+	crtHs := cs.getHandshakeBytes(v)
+	if crtHs == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Failed to create connection state")
+		return false
+	}
+	hh.hostinfo.ConnectionState = ci
 
 	hs := &NebulaHandshake{
-		Details: hsProto,
+		Details: &NebulaHandshakeDetails{
+			InitiatorIndex: hh.hostinfo.localIndexId,
+			Time:           uint64(time.Now().UnixNano()),
+			Cert:           crtHs,
+			CertVersion:    uint32(v),
+		},
 	}
-	hsBytes, err = hs.Marshal()
 
+	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("certVersion", v).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 	}
 
 	h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
-	ci.messageCounter.Add(1)
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 	}
@@ -64,67 +95,147 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	return true
 }
 
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
+func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
+	cs := f.pki.getCertState()
+	crt := cs.GetDefaultCertificate()
+	if crt == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", cs.defaultVersion).
+			Error("Unable to handshake with host because no certificate is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to create connection state")
+		return
+	}
+
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to call noise.ReadMessage")
 		return
 	}
 
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
-	/*
-		l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
-	*/
 	if err != nil || hs.Details == nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed unmarshal handshake message")
+		return
+	}
+
+	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
+	if err != nil {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Info("Handshake did not contain a certificate")
 		return
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
 	if err != nil {
+		fp, err := rc.Fingerprint()
+		if err != nil {
+			fp = "<error generating certificate fingerprint>"
+		}
+
 		e := f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("certVpnNetworks", rc.Networks()).
+			WithField("certFingerprint", fp)
 
-		if f.l.Level > logrus.DebugLevel {
-			e = e.WithField("cert", remoteCert)
+		if f.l.Level >= logrus.DebugLevel {
+			e = e.WithField("cert", rc)
 		}
 
 		e.Info("Invalid certificate from host")
 		return
 	}
-	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
-	certName := remoteCert.Details.Name
-	fingerprint, _ := remoteCert.Sha256Sum()
-	issuer := remoteCert.Details.Issuer
 
-	if vpnIp == f.myVpnIp {
-		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	if remoteCert.Certificate.Version() != ci.myCert.Version() {
+		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
+		rc := cs.getCertificate(remoteCert.Certificate.Version())
+		if rc == nil {
+			f.l.WithError(err).WithField("udpAddr", addr).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
+				Info("Unable to handshake with host due to missing certificate version")
+			return
+		}
+
+		// Record the certificate we are actually using
+		ci.myCert = rc
+	}
+
+	if len(remoteCert.Certificate.Networks()) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("cert", remoteCert).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Info("No networks in certificate")
+		return
+	}
+
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	certName := remoteCert.Certificate.Name()
+	certVersion := remoteCert.Certificate.Version()
+	fingerprint := remoteCert.Fingerprint
+	issuer := remoteCert.Certificate.Issuer()
+
+	for _, network := range remoteCert.Certificate.Networks() {
+		vpnAddr := network.Addr()
+		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
+		if found {
+			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("certVersion", certVersion).
+				WithField("fingerprint", fingerprint).
+				WithField("issuer", issuer).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			return
+		}
+
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
 		return
 	}
 
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if addr.IsValid() {
+		// addr can be invalid when the tunnel is being relayed.
+		// We only want to apply the remote allow list for direct tunnels here
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
 
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index")
@@ -136,19 +247,20 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		vpnIp:             vpnIp,
+		vpnAddrs:          vpnAddrs,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
-			syncRWMutex:   newSyncRWMutex("relay-state"),
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			syncRWMutex:    newSyncRWMutex("relay-state"),
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}
 
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
+		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -156,14 +268,29 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		Info("Handshake message received")
 
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = certState.RawCertificateNoKey
+	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
+	if hs.Details.Cert == nil {
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("certVersion", certVersion).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("certVersion", ci.myCert.Version()).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return
+	}
+
+	hs.Details.CertVersion = uint32(ci.myCert.Version())
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
 	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
@@ -173,15 +300,17 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Noise did not arrive at a key")
@@ -204,9 +333,9 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
-	hostinfo.CreateRemoteCIDR(remoteCert)
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -216,19 +345,19 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
-				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-			if addr != nil {
+			if addr.IsValid() {
 				err := f.outside.WriteTo(msg, addr)
 				if err != nil {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 				} else {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 				}
@@ -238,17 +367,18 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 				}
-				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
+				f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 				return
 			}
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
 				WithField("fingerprint", fingerprint).
@@ -258,24 +388,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 				Info("Handshake too old")
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -287,19 +419,21 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 
 	// Do the send
 	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-	if addr != nil {
+	if addr.IsValid() {
 		err = f.outside.WriteTo(msg, addr)
 		if err != nil {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake")
 		} else {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -311,10 +445,14 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 		}
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+		// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
+		// it's correctly marked as working.
+		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -323,13 +461,13 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	}
 
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.ConnectionState.messageCounter.Store(2)
+
 	hostinfo.remotes.ResetBlockedRemotes()
 
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
 	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -339,9 +477,10 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	defer hh.Unlock()
 
 	hostinfo := hh.hostinfo
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if addr.IsValid() {
+		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
 	}
@@ -349,7 +488,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -358,7 +497,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 
@@ -370,37 +509,102 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
 		return true
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	if err != nil {
-		e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
+			Info("Handshake did not contain a certificate")
+		return true
+	}
+
+	remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
+	if err != nil {
+		fp, err := rc.Fingerprint()
+		if err != nil {
+			fp = "<error generating certificate fingerprint>"
+		}
 
-		if f.l.Level > logrus.DebugLevel {
-			e = e.WithField("cert", remoteCert)
+		e := f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
+			WithField("certFingerprint", fp).
+			WithField("certVpnNetworks", rc.Networks())
+
+		if f.l.Level >= logrus.DebugLevel {
+			e = e.WithField("cert", rc)
 		}
 
-		e.Error("Invalid certificate from host")
+		e.Info("Invalid certificate from host")
+		return true
+	}
 
-		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
+	if len(remoteCert.Certificate.Networks()) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("cert", remoteCert).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
+			Info("No networks in certificate")
 		return true
 	}
 
-	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
-	certName := remoteCert.Details.Name
-	fingerprint, _ := remoteCert.Sha256Sum()
-	issuer := remoteCert.Details.Issuer
+	vpnNetworks := remoteCert.Certificate.Networks()
+	certName := remoteCert.Certificate.Name()
+	certVersion := remoteCert.Certificate.Version()
+	fingerprint := remoteCert.Fingerprint
+	issuer := remoteCert.Certificate.Issuer()
+
+	hostinfo.remoteIndexId = hs.Details.ResponderIndex
+	hostinfo.lastHandshakeTime = hs.Details.Time
+
+	// Store their cert and our symmetric keys
+	ci.peerCert = remoteCert
+	ci.dKey = NewNebulaCipherState(dKey)
+	ci.eKey = NewNebulaCipherState(eKey)
+
+	// Make sure the current udpAddr being used is set for responding
+	if addr.IsValid() {
+		hostinfo.SetRemote(addr)
+	} else {
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+	}
+
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	for _, network := range vpnNetworks {
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		vpnAddr := network.Addr()
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("certVersion", certVersion).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+		return true
+	}
 
 	// Ensure the right host responded
-	if vpnIp != hostinfo.vpnIp {
-		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
-			WithField("udpAddr", addr).WithField("certName", certName).
+	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
+			WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("certVersion", certVersion).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 
@@ -408,16 +612,13 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
-			//TODO: this doesnt know if its being added or is being used for caching a packet
+		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes.BlockRemote(addr)
 
-			// Get the correct remote list for the host we did handshake with
-			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
-
-			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
+				WithField("vpnNetworks", vpnNetworks).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				Info("Blocked addresses for handshakes")
 
@@ -425,8 +626,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 			newHH.packetStore = hh.packetStore
 			hh.packetStore = []*cachedPacket{}
 
-			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-			hostinfo.vpnIp = vpnIp
+			// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnAddrs = vpnAddrs
 			f.sendCloseTunnel(hostinfo)
 		})
 
@@ -437,8 +638,9 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
+		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
@@ -447,30 +649,14 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 		WithField("sentCachedPackets", len(hh.packetStore)).
 		Info("Handshake message received")
 
-	hostinfo.remoteIndexId = hs.Details.ResponderIndex
-	hostinfo.lastHandshakeTime = hs.Details.Time
-
-	// Store their cert and our symmetric keys
-	ci.peerCert = remoteCert
-	ci.dKey = NewNebulaCipherState(dKey)
-	ci.eKey = NewNebulaCipherState(eKey)
-
-	// Make sure the current udpAddr being used is set for responding
-	if addr != nil {
-		hostinfo.SetRemote(addr)
-	} else {
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
-	}
-
 	// Build up the radix for the firewall if we have subnets in the cert
-	hostinfo.CreateRemoteCIDR(remoteCert)
+	hostinfo.vpnAddrs = vpnAddrs
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
-	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
+	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 
-	hostinfo.ConnectionState.messageCounter.Store(2)
-
 	if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
 	}

+ 215 - 149
handshake_manager.go

@@ -6,13 +6,14 @@ import (
 	"crypto/rand"
 	"encoding/binary"
 	"errors"
-	"net"
+	"net/netip"
+	"slices"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 )
 
@@ -34,7 +35,7 @@ var (
 
 type HandshakeConfig struct {
 	tryInterval   time.Duration
-	retries       int
+	retries       int64
 	triggerBuffer int
 	useRelays     bool
 
@@ -45,14 +46,14 @@ type HandshakeManager struct {
 	// Mutex for interacting with the vpnIps and indexes maps
 	syncRWMutex
 
-	vpnIps  map[iputil.VpnIp]*HandshakeHostInfo
+	vpnIps  map[netip.Addr]*HandshakeHostInfo
 	indexes map[uint32]*HandshakeHostInfo
 
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
 	outside                udp.Conn
 	config                 HandshakeConfig
-	OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
+	OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
@@ -60,17 +61,17 @@ type HandshakeManager struct {
 	l                      *logrus.Logger
 
 	// can be used to trigger outbound handshake for the given vpnIp
-	trigger chan iputil.VpnIp
+	trigger chan netip.Addr
 }
 
 type HandshakeHostInfo struct {
 	syncMutex
 
-	startTime   time.Time       // Time that we first started trying with this handshake
-	ready       bool            // Is the handshake ready
-	counter     int             // How many attempts have we made so far
-	lastRemotes []*udp.Addr     // Remotes that we sent to during the previous attempt
-	packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+	startTime   time.Time        // Time that we first started trying with this handshake
+	ready       bool             // Is the handshake ready
+	counter     int64            // How many attempts have we made so far
+	lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
+	packetStore []*cachedPacket  // A set of packets to be transmitted once the handshake completes
 
 	hostinfo *HostInfo
 }
@@ -103,14 +104,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
 func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
 		syncRWMutex:            newSyncRWMutex("handshake-manager"),
-		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
+		vpnIps:                 map[netip.Addr]*HandshakeHostInfo{},
 		indexes:                map[uint32]*HandshakeHostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
 		config:                 config,
-		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
-		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp]("handshake-manager-timer", config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		trigger:                make(chan netip.Addr, config.triggerBuffer),
+		OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr]("handshake-manager-timer", config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -118,26 +119,26 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context) {
-	clockSource := time.NewTicker(c.config.tryInterval)
+func (hm *HandshakeManager) Run(ctx context.Context) {
+	clockSource := time.NewTicker(hm.config.tryInterval)
 	defer clockSource.Stop()
 
 	for {
 		select {
 		case <-ctx.Done():
 			return
-		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, true)
+		case vpnIP := <-hm.trigger:
+			hm.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now)
+			hm.NextOutboundHandshakeTimerTick(now)
 		}
 	}
 }
 
-func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
-	if addr != nil {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+	if addr.IsValid() {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
@@ -159,18 +160,18 @@ func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packe
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
-	c.OutboundHandshakeTimer.Advance(now)
+func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
+	hm.OutboundHandshakeTimer.Advance(now)
 	for {
-		vpnIp, has := c.OutboundHandshakeTimer.Purge()
+		vpnIp, has := hm.OutboundHandshakeTimer.Purge()
 		if !has {
 			break
 		}
-		c.handleOutbound(vpnIp, false)
+		hm.handleOutbound(vpnIp, false)
 	}
 }
 
-func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
 	hh := hm.queryVpnIp(vpnIp)
 	if hh == nil {
 		return
@@ -208,11 +209,11 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
 	}
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
-	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
+	remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// This is a very specific optimization for a fast lighthouse reply.
@@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 
 	hh.lastRemotes = remotes
 
-	// TODO: this will generate a load of queries for hosts with only 1 ip
+	// This will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
 	if len(remotes) <= 1 && hh.counter == 5 {
@@ -234,8 +235,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
-	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
+	var sentTo []netip.AddrPort
+	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
 		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
@@ -256,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("Handshake message sent")
-	} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
+	} else if hm.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
@@ -267,56 +268,28 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
+			// Don't relay to myself
+			if relay == vpnIp {
 				continue
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
-			if relayHostInfo == nil || relayHostInfo.remote == nil {
+
+			// Don't relay through the host I'm trying to connect to
+			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
+			if found {
+				continue
+			}
+
+			relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
+			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				hm.f.Handshake(*relay)
+				hm.f.Handshake(relay)
 				continue
 			}
-			// Check the relay HostInfo to see if we already established a relay through it
-			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
-				switch existingRelay.State {
-				case Established:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
-				case Requested:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
-					// Re-send the CreateRelay request, in case the previous one was lost.
-					m := NebulaControl{
-						Type:                NebulaControl_CreateRelayRequest,
-						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
-						RelayToIp:           uint32(vpnIp),
-					}
-					msg, err := m.Marshal()
-					if err != nil {
-						hostinfo.logger(hm.l).
-							WithError(err).
-							Error("Failed to marshal Control message to create relay")
-					} else {
-						// This must send over the hostinfo, not over hm.Hosts[ip]
-						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.lightHouse.myVpnIp,
-							"relayTo":             vpnIp,
-							"initiatorRelayIndex": existingRelay.LocalIndex,
-							"relay":               *relay}).
-							Info("send CreateRelayRequest")
-					}
-				default:
-					hostinfo.logger(hm.l).
-						WithField("vpnIp", vpnIp).
-						WithField("state", existingRelay.State).
-						WithField("relay", relayHostInfo.vpnIp).
-						Errorf("Relay unexpected state")
-				}
-			} else {
+			// Check the relay HostInfo to see if we already established a relay through
+			existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
+			if !ok {
 				// No relays exist or requested yet.
-				if relayHostInfo.remote != nil {
+				if relayHostInfo.remote.IsValid() {
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					if err != nil {
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
@@ -325,9 +298,32 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
-						RelayToIp:           uint32(vpnIp),
 					}
+
+					switch relayHostInfo.GetCert().Certificate.Version() {
+					case cert.Version1:
+						if !hm.f.myVpnAddrs[0].Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+							continue
+						}
+
+						if !vpnIp.Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+							continue
+						}
+
+						b := hm.f.myVpnAddrs[0].As4()
+						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+						b = vpnIp.As4()
+						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					case cert.Version2:
+						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					default:
+						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						continue
+					}
+
 					msg, err := m.Marshal()
 					if err != nil {
 						hostinfo.logger(hm.l).
@@ -336,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.lightHouse.myVpnIp,
+							"relayFrom":           hm.f.myVpnAddrs[0],
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
-							"relay":               *relay}).
+							"relay":               relay}).
 							Info("send CreateRelayRequest")
 					}
 				}
+				continue
+			}
+
+			switch existingRelay.State {
+			case Established:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+				hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+			case Disestablished:
+				// Mark this relay as 'requested'
+				relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
+				fallthrough
+			case Requested:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+				// Re-send the CreateRelay request, in case the previous one was lost.
+				m := NebulaControl{
+					Type:                NebulaControl_CreateRelayRequest,
+					InitiatorRelayIndex: existingRelay.LocalIndex,
+				}
+
+				switch relayHostInfo.GetCert().Certificate.Version() {
+				case cert.Version1:
+					if !hm.f.myVpnAddrs[0].Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+						continue
+					}
+
+					if !vpnIp.Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+						continue
+					}
+
+					b := hm.f.myVpnAddrs[0].As4()
+					m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+					b = vpnIp.As4()
+					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+				case cert.Version2:
+					m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+				default:
+					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+					continue
+				}
+				msg, err := m.Marshal()
+				if err != nil {
+					hostinfo.logger(hm.l).
+						WithError(err).
+						Error("Failed to marshal Control message to create relay")
+				} else {
+					// This must send over the hostinfo, not over hm.Hosts[ip]
+					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+					hm.l.WithFields(logrus.Fields{
+						"relayFrom":           hm.f.myVpnAddrs[0],
+						"relayTo":             vpnIp,
+						"initiatorRelayIndex": existingRelay.LocalIndex,
+						"relay":               relay}).
+						Info("send CreateRelayRequest")
+				}
+			case PeerRequested:
+				// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
+				fallthrough
+			default:
+				hostinfo.logger(hm.l).
+					WithField("vpnIp", vpnIp).
+					WithField("state", existingRelay.State).
+					WithField("relay", relay).
+					Errorf("Relay unexpected state")
+
 			}
 		}
 	}
@@ -355,11 +418,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 
 // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
 // The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	// Check the main hostmap and maintain a read lock if our host is not there
+func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	hm.mainHostMap.RLock()
-	if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
-		hm.mainHostMap.RUnlock()
+	h, ok := hm.mainHostMap.Hosts[vpnIp]
+	hm.mainHostMap.RUnlock()
+
+	if ok {
 		// Do not attempt promotion if you are a lighthouse
 		if !hm.lightHouse.amLighthouse {
 			h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
@@ -367,15 +431,14 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 		return h, true
 	}
 
-	defer hm.mainHostMap.RUnlock()
 	return hm.StartHandshake(vpnIp, cacheCb), false
 }
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 
-	if hh, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnAddr]; ok {
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
 			cacheCb(hh)
@@ -386,13 +449,13 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 
 	hostinfo := &HostInfo{
 		syncRWMutex:     newSyncRWMutex("hostinfo"),
-		vpnIp:           vpnIp,
+		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			syncRWMutex:   newSyncRWMutex("relay-state"),
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			syncRWMutex:    newSyncRWMutex("relay-state"),
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}
 
@@ -401,9 +464,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 		hostinfo:  hostinfo,
 		startTime: time.Now(),
 	}
-	hm.vpnIps[vpnIp] = hh
+	hm.vpnIps[vpnAddr] = hh
 	hm.metricInitiated.Inc(1)
-	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
+	hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
 
 	if cacheCb != nil {
 		cacheCb(hh)
@@ -411,21 +474,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// We can trigger the handshake right now
-	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
 	if !doTrigger {
 		// Add any calculated remotes, and trigger early handshake if one found
-		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
 	}
 
 	if doTrigger {
 		select {
-		case hm.trigger <- vpnIp:
+		case hm.trigger <- vpnAddr:
 		default:
 		}
 	}
 
 	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp)
+	hm.lightHouse.QueryServer(vpnAddr)
 	return hostinfo
 }
 
@@ -446,14 +509,14 @@ var (
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
-	c.Lock()
-	defer c.Unlock()
+func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
 	if found && existingHostInfo != nil {
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
@@ -470,31 +533,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			return existingHostInfo, ErrExistingHostInfo
 		}
 
-		existingHostInfo.logger(c.l).Info("Taking new handshake")
+		existingHostInfo.logger(hm.l).Info("Taking new handshake")
 	}
 
-	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
 	if found {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
 	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
-		return existingIndex, ErrLocalIndexCollision
+		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 	}
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+		hostinfo.logger(hm.l).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 	}
 
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 }
 
@@ -512,7 +575,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(hm.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 	}
 
@@ -549,31 +612,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 }
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	c.Lock()
-	defer c.Unlock()
-	c.unlockedDeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
 }
 
-func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	delete(c.vpnIps, hostinfo.vpnIp)
-	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
+func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	for _, addr := range hostinfo.vpnAddrs {
+		delete(hm.vpnIps, addr)
 	}
 
-	delete(c.indexes, hostinfo.localIndexId)
-	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HandshakeHostInfo{}
+	if len(hm.vpnIps) == 0 {
+		hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+	delete(hm.indexes, hostinfo.localIndexId)
+	if len(hm.indexes) == 0 {
+		hm.indexes = map[uint32]*HandshakeHostInfo{}
+	}
+
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Pending hostmap hostInfo deleted")
 	}
 }
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 		return hh.hostinfo
@@ -582,7 +648,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 
 }
 
-func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
 	hm.RLock()
 	defer hm.RUnlock()
 	return hm.vpnIps[vpnIp]
@@ -602,37 +668,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 }
 
-func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
-	return c.mainHostMap.GetPreferredRanges()
+func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
+	return hm.mainHostMap.GetPreferredRanges()
 }
 
-func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
-	for _, v := range c.vpnIps {
+	for _, v := range hm.vpnIps {
 		f(v.hostinfo)
 	}
 }
 
-func (c *HandshakeManager) ForEachIndex(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
-	for _, v := range c.indexes {
+	for _, v := range hm.indexes {
 		f(v.hostinfo)
 	}
 }
 
-func (c *HandshakeManager) EmitStats() {
-	c.RLock()
-	hostLen := len(c.vpnIps)
-	indexLen := len(c.indexes)
-	c.RUnlock()
+func (hm *HandshakeManager) EmitStats() {
+	hm.RLock()
+	hostLen := len(hm.vpnIps)
+	indexLen := len(hm.indexes)
+	hm.RUnlock()
 
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
-	c.mainHostMap.EmitStats()
+	hm.mainHostMap.EmitStats()
 }
 
 // Utility functions below
@@ -659,6 +725,6 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
 	return index, nil
 }
 
-func hsTimeout(tries int, interval time.Duration) time.Duration {
-	return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
+func hsTimeout(tries int64, interval time.Duration) time.Duration {
+	return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
 }

+ 25 - 18
handshake_manager_test.go

@@ -1,13 +1,12 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
@@ -15,20 +14,20 @@ import (
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
-	mainHM := newHostMap(l, vpncidr)
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	ip := netip.MustParseAddr("172.1.1.2")
+
+	preferredRanges := []netip.Prefix{localrange}
+	mainHM := newHostMap(l)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
 	lh := newTestLighthouse()
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -42,10 +41,10 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 
-	i.remotes = NewRemoteList(nil)
+	i.remotes = NewRemoteList([]netip.Addr{}, nil)
 
 	// Adding something to pending should not affect the main hostmap
-	assert.Len(t, mainHM.Hosts, 0)
+	assert.Empty(t, mainHM.Hosts)
 
 	// Confirm they are in the pending index list
 	assert.Contains(t, blah.vpnIps, ip)
@@ -66,7 +65,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.NotContains(t, blah.vpnIps, ip)
 }
 
-func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
+func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 	for _, i := range tw.t.wheel {
 		n := i.Head
 		for n != nil {
@@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
 type mockEncWriter struct {
 }
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
 	return
 }
 
-func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
 	return
 }
 
-func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
 	return
 }
 
-func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}
+func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
+
+func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
+	return nil
+}
+
+func (mw *mockEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: cert.Version2}
+}

+ 1 - 1
header/header.go

@@ -19,7 +19,7 @@ import (
 // |-----------------------------------------------------------------------|
 // |                               payload...                              |
 
-type m map[string]interface{}
+type m = map[string]any
 
 const (
 	Version uint8 = 1

+ 2 - 1
header/header_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 type headerTest struct {
@@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
 
 func TestHeader_MarshalJSON(t *testing.T) {
 	b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	assert.Equal(
 		t,
 		"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",

+ 209 - 138
hostmap.go

@@ -3,17 +3,16 @@ package nebula
 import (
 	"errors"
 	"net"
+	"net/netip"
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 // const ProbeLen = 100
@@ -35,6 +34,7 @@ const (
 	Requested = iota
 	PeerRequested
 	Established
+	Disestablished
 )
 
 const (
@@ -48,7 +48,7 @@ type Relay struct {
 	State       int
 	LocalIndex  uint32
 	RemoteIndex uint32
-	PeerIp      iputil.VpnIp
+	PeerAddr    netip.Addr
 }
 
 type HostMap struct {
@@ -56,9 +56,8 @@ type HostMap struct {
 	Indexes         map[uint32]*HostInfo
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	RemoteIndexes   map[uint32]*HostInfo
-	Hosts           map[iputil.VpnIp]*HostInfo
-	preferredRanges atomic.Pointer[[]*net.IPNet]
-	vpnCIDR         *net.IPNet
+	Hosts           map[netip.Addr]*HostInfo
+	preferredRanges atomic.Pointer[[]netip.Prefix]
 	l               *logrus.Logger
 }
 
@@ -68,17 +67,42 @@ type HostMap struct {
 type RelayState struct {
 	syncRWMutex
 
-	relays        map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[iputil.VpnIp]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay         // Maps a local index to some Relay info
+	relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
+	// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
+	// the RelayState Lock held)
+	relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx  map[uint32]*Relay     // Maps a local index to some Relay info
 }
 
-func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
+func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
 	delete(rs.relays, ip)
 }
 
+func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByAddr[vpnIp]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
+func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByIdx[idx]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
 func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	rs.RLock()
 	defer rs.RUnlock()
@@ -89,34 +113,34 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 }
 
-func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[ip]
+	r, ok := rs.relayForByAddr[addr]
 	return r, ok
 }
 
-func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) {
+func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
 	rs.relays[ip] = struct{}{}
 }
 
-func (rs *RelayState) CopyRelayIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayIps() []netip.Addr {
 	rs.RLock()
 	defer rs.RUnlock()
-	ret := make([]iputil.VpnIp, 0, len(rs.relays))
+	ret := make([]netip.Addr, 0, len(rs.relays))
 	for ip := range rs.relays {
 		ret = append(ret, ip)
 	}
 	return ret
 }
 
-func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	defer rs.RUnlock()
-	currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp))
-	for relayIp := range rs.relayForByIp {
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
+	for relayIp := range rs.relayForByAddr {
 		currentRelays = append(currentRelays, relayIp)
 	}
 	return currentRelays
@@ -132,22 +156,10 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 	return ret
 }
 
-func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
-	rs.Lock()
-	defer rs.Unlock()
-	r, ok := rs.relayForByIdx[localIdx]
-	if !ok {
-		return iputil.VpnIp(0), false
-	}
-	delete(rs.relayForByIdx, localIdx)
-	delete(rs.relayForByIp, r.PeerIp)
-	return r.PeerIp, true
-}
-
-func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
+func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	defer rs.Unlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	if !ok {
 		return false
 	}
@@ -155,7 +167,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bo
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return true
 }
 
@@ -170,14 +182,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return &newRelay, true
 }
 
-func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	return r, ok
 }
 
@@ -188,25 +200,31 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 	return r, ok
 }
 
-func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
+func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	defer rs.Unlock()
-	rs.relayForByIp[ip] = r
+	rs.relayForByAddr[ip] = r
 	rs.relayForByIdx[idx] = r
 }
 
 type HostInfo struct {
 	syncRWMutex
-	remote          *udp.Addr
+	remote          netip.AddrPort
 	remotes         *RemoteList
 	promoteCounter  atomic.Uint32
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	localIndexId    uint32
-	vpnIp           iputil.VpnIp
-	recvError       atomic.Uint32
-	remoteCidr      *cidr.Tree4[struct{}]
-	relayState      RelayState
+
+	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
+	// The host may have other vpn addresses that are outside our
+	// vpn networks but were removed because they are not usable
+	vpnAddrs  []netip.Addr
+	recvError atomic.Uint32
+
+	// networks are both all vpn and unsafe networks assigned to this host
+	networks   *bart.Table[struct{}]
+	relayState RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
 	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
@@ -227,7 +245,7 @@ type HostInfo struct {
 	lastHandshakeTime uint64
 
 	lastRoam       time.Time
-	lastRoamRemote *udp.Addr
+	lastRoamRemote netip.AddrPort
 
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Synchronised via hostmap lock and not the hostinfo lock.
@@ -254,40 +272,38 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
-	hm := newHostMap(l, vpnCIDR)
+func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
+	hm := newHostMap(l)
 
 	hm.reload(c, true)
 	c.RegisterReloadCallback(func(c *config.C) {
 		hm.reload(c, false)
 	})
 
-	l.WithField("network", hm.vpnCIDR.String()).
-		WithField("preferredRanges", hm.GetPreferredRanges()).
+	l.WithField("preferredRanges", hm.GetPreferredRanges()).
 		Info("Main HostMap created")
 
 	return hm
 }
 
-func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
+func newHostMap(l *logrus.Logger) *HostMap {
 	return &HostMap{
 		syncRWMutex:   newSyncRWMutex("hostmap"),
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
-		Hosts:         map[iputil.VpnIp]*HostInfo{},
-		vpnCIDR:       vpnCIDR,
+		Hosts:         map[netip.Addr]*HostInfo{},
 		l:             l,
 	}
 }
 
 func (hm *HostMap) reload(c *config.C, initial bool) {
 	if initial || c.HasChanged("preferred_ranges") {
-		var preferredRanges []*net.IPNet
+		var preferredRanges []netip.Prefix
 		rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
 
 		for _, rawPreferredRange := range rawPreferredRanges {
-			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
+			preferredRange, err := netip.ParsePrefix(rawPreferredRange)
 
 			if err != nil {
 				hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
@@ -319,17 +335,6 @@ func (hm *HostMap) EmitStats() {
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 
-func (hm *HostMap) RemoveRelay(localIdx uint32) {
-	hm.Lock()
-	_, ok := hm.Relays[localIdx]
-	if !ok {
-		hm.Unlock()
-		return
-	}
-	delete(hm.Relays, localIdx)
-	hm.Unlock()
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
@@ -349,48 +354,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 }
 
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
-	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	// Get the current primary, if it exists
+	oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
+
+	// Every address in the hostinfo gets elevated to primary
+	for _, vpnAddr := range hostinfo.vpnAddrs {
+		//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
+		// indexes so it should be fine.
+		hm.Hosts[vpnAddr] = hostinfo
+	}
+
+	// If we are already primary then we won't bother re-linking
 	if oldHostinfo == hostinfo {
 		return
 	}
 
+	// Unlink this hostinfo
 	if hostinfo.prev != nil {
 		hostinfo.prev.next = hostinfo.next
 	}
-
 	if hostinfo.next != nil {
 		hostinfo.next.prev = hostinfo.prev
 	}
 
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
+	// If there wasn't a previous primary then clear out any links
 	if oldHostinfo == nil {
+		hostinfo.next = nil
+		hostinfo.prev = nil
 		return
 	}
 
+	// Relink the hostinfo as primary
 	hostinfo.next = oldHostinfo
 	oldHostinfo.prev = hostinfo
 	hostinfo.prev = nil
 }
 
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	for _, addr := range hostinfo.vpnAddrs {
+		h := hm.Hosts[addr]
+		for h != nil {
+			if h == hostinfo {
+				hm.unlockedInnerDeleteHostInfo(h, addr)
+			}
+			h = h.next
+		}
+	}
+}
+
+func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
+	primary, ok := hm.Hosts[addr]
+	isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
 	if ok && primary == hostinfo {
-		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
-		delete(hm.Hosts, hostinfo.vpnIp)
+		// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, addr)
 		if len(hm.Hosts) == 0 {
-			hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 
 		if hostinfo.next != nil {
-			// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
-			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
+			hm.Hosts[addr] = hostinfo.next
 			// It is primary, there is no previous hostinfo now
 			hostinfo.next.prev = nil
 		}
 
 	} else {
-		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		// Relink if we were in the middle of multiple hostinfos for this vpn addr
 		if hostinfo.prev != nil {
 			hostinfo.prev.next = hostinfo.next
 		}
@@ -420,10 +450,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
 
+	if isLastHostinfo {
+		// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
+		// hops as 'Requested' so that new relay tunnels are created in the future.
+		hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
+	}
+	// Clean up any local relay indexes for which I am acting as a relay hop
 	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
 		delete(hm.Relays, localRelayIdx)
 	}
@@ -462,11 +498,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 }
 
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
-	return hm.queryVpnIp(vpnIp, nil)
+func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
+	return hm.queryVpnAddr(vpnIp, nil)
 }
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -474,17 +510,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
 	if !ok {
 		return nil, nil, errors.New("unable to find host")
 	}
+
 	for h != nil {
-		r, ok := h.relayState.QueryRelayForByIp(targetIp)
-		if ok && r.State == Established {
-			return h, r, nil
+		for _, targetIp := range targetIps {
+			r, ok := h.relayState.QueryRelayForByIp(targetIp)
+			if ok && r.State == Established {
+				return h, r, nil
+			}
 		}
 		h = h.next
 	}
+
 	return nil, nil, errors.New("unable to find host with relay")
 }
 
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
+	for _, relayHostIp := range hi.relayState.CopyRelayIps() {
+		if h, ok := hm.Hosts[relayHostIp]; ok {
+			for h != nil {
+				h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+				h = h.next
+			}
+		}
+	}
+	for _, rs := range hi.relayState.CopyAllRelayFor() {
+		if rs.Type == ForwardingType {
+			if h, ok := hm.Hosts[rs.PeerAddr]; ok {
+				for h != nil {
+					h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+					h = h.next
+				}
+			}
+		}
+	}
+}
+
+func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -505,25 +566,30 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
-		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
+		dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
 	}
-
-	existing := hm.Hosts[hostinfo.vpnIp]
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
-	if existing != nil {
-		hostinfo.next = existing
-		existing.prev = hostinfo
+	for _, addr := range hostinfo.vpnAddrs {
+		hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
 	}
 
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
+		hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
 			Debug("Hostmap vpnIp added")
 	}
+}
+
+func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
+	existing := hm.Hosts[vpnAddr]
+	hm.Hosts[vpnAddr] = hostinfo
+
+	if existing != nil && existing != hostinfo {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
 
 	i := 1
 	check := hostinfo
@@ -536,12 +602,12 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }
 
-func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
+func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
 	return *hm.preferredRanges.Load()
 }
 
-func (hm *HostMap) ForEachVpnIp(f controlEach) {
+func (hm *HostMap) ForEachVpnAddr(f controlEach) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -561,14 +627,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
-func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
+func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
 	if c%ifce.tryPromoteEvery.Load() == 0 {
 		remote := i.remote
 
 		// return early if we are already on a preferred remote
-		if remote != nil {
-			rIP := remote.IP
+		if remote.IsValid() {
+			rIP := remote.Addr()
 			for _, l := range preferredRanges {
 				if l.Contains(rIP) {
 					return
@@ -576,8 +642,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 		}
 
-		i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
-			if remote != nil && (addr == nil || !preferred) {
+		i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
+			if remote.IsValid() && (!addr.IsValid() || !preferred) {
 				return
 			}
 
@@ -595,34 +661,34 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 		}
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp)
+		ifce.lightHouse.QueryServer(i.vpnAddrs[0])
 	}
 }
 
-func (i *HostInfo) GetCert() *cert.NebulaCertificate {
+func (i *HostInfo) GetCert() *cert.CachedCertificate {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
 	}
 	return nil
 }
 
-func (i *HostInfo) SetRemote(remote *udp.Addr) {
+func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
-	if !i.remote.Equals(remote) {
-		i.remote = remote.Copy()
-		i.remotes.LearnRemote(i.vpnIp, remote.Copy())
+	if i.remote != remote {
+		i.remote = remote
+		i.remotes.LearnRemote(i.vpnAddrs[0], remote)
 	}
 }
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
-	if newRemote == nil {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
+	if !newRemote.IsValid() {
 		// relays have nil udp Addrs
 		return false
 	}
 	currentRemote := i.remote
-	if currentRemote == nil {
+	if !currentRemote.IsValid() {
 		i.SetRemote(newRemote)
 		return true
 	}
@@ -632,11 +698,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	newIsPreferred := false
 	for _, l := range hm.GetPreferredRanges() {
 		// return early if we are already on a preferred remote
-		if l.Contains(currentRemote.IP) {
+		if l.Contains(currentRemote.Addr()) {
 			return false
 		}
 
-		if l.Contains(newRemote.IP) {
+		if l.Contains(newRemote.Addr()) {
 			newIsPreferred = true
 		}
 	}
@@ -644,7 +710,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	if newIsPreferred {
 		// Consider this a roaming event
 		i.lastRoam = time.Now()
-		i.lastRoamRemote = currentRemote.Copy()
+		i.lastRoamRemote = currentRemote
 
 		i.SetRemote(newRemote)
 
@@ -661,21 +727,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 }
 
-func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
-	if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
+func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+	if len(networks) == 1 && len(unsafeNetworks) == 0 {
 		// Simple case, no CIDRTree needed
 		return
 	}
 
-	remoteCidr := cidr.NewTree4[struct{}]()
-	for _, ip := range c.Details.Ips {
-		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+	i.networks = new(bart.Table[struct{}])
+	for _, network := range networks {
+		i.networks.Insert(network, struct{}{})
 	}
 
-	for _, n := range c.Details.Subnets {
-		remoteCidr.AddCIDR(n, struct{}{})
+	for _, network := range unsafeNetworks {
+		i.networks.Insert(network, struct{}{})
 	}
-	i.remoteCidr = remoteCidr
 }
 
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@@ -683,13 +748,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 	}
 
-	li := l.WithField("vpnIp", i.vpnIp).
+	li := l.WithField("vpnAddrs", i.vpnAddrs).
 		WithField("localIndex", i.localIndexId).
 		WithField("remoteIndex", i.remoteIndexId)
 
 	if connState := i.ConnectionState; connState != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
-			li = li.WithField("certName", peerCert.Details.Name)
+			li = li.WithField("certName", peerCert.Certificate.Name())
 		}
 	}
 
@@ -698,9 +763,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 // Utility functions
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
+func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
-	var ips []net.IP
+	var finalAddrs []netip.Addr
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
@@ -712,30 +777,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
 			continue
 		}
 		addrs, _ := i.Addrs()
-		for _, addr := range addrs {
-			var ip net.IP
-			switch v := addr.(type) {
+		for _, rawAddr := range addrs {
+			var addr netip.Addr
+			switch v := rawAddr.(type) {
 			case *net.IPNet:
 				//continue
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			case *net.IPAddr:
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
+			}
+
+			if !addr.IsValid() {
+				if l.Level >= logrus.DebugLevel {
+					l.WithField("localAddr", rawAddr).Debug("addr was invalid")
+				}
+				continue
 			}
+			addr = addr.Unmap()
 
-			//TODO: Filtering out link local for now, this is probably the most correct thing
-			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
-				allow := allowList.Allow(ip)
+			if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
+				isAllowed := allowList.Allow(addr)
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
 				}
-				if !allow {
+				if !isAllowed {
 					continue
 				}
 
-				ips = append(ips, ip)
+				finalAddrs = append(finalAddrs, addr)
 			}
 		}
 	}
-	return &ips
+	return finalAddrs
 }

+ 27 - 46
hostmap_test.go

@@ -1,7 +1,7 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"testing"
 
 	"github.com/slackhq/nebula/config"
@@ -11,20 +11,14 @@ import (
 
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
-	)
+	hm := newHostMap(l)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
-	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
-	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
-	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -32,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -47,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -62,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -77,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -91,22 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
 
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
-	)
+	hm := newHostMap(l)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
-	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
-	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
-	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
-	h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
-	h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
+	h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
+	h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -122,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -141,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -159,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -175,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -189,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
@@ -201,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 }
 
@@ -209,16 +197,9 @@ func TestHostMap_reload(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 
-	hm := NewHostMapFromConfig(
-		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
-		c,
-	)
+	hm := NewHostMapFromConfig(l, c)
 
-	toS := func(ipn []*net.IPNet) []string {
+	toS := func(ipn []netip.Prefix) []string {
 		var s []string
 		for _, n := range ipn {
 			s = append(s, n.String())
@@ -229,8 +210,8 @@ func TestHostMap_reload(t *testing.T) {
 	assert.Empty(t, hm.GetPreferredRanges())
 
 	c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
-	assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
+	assert.Equal(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))
 
 	c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
-	assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
+	assert.Equal(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
 }

+ 5 - 3
hostmap_tester.go

@@ -5,10 +5,12 @@ package nebula
 
 // This file contains functions used to export information to the e2e testing framework
 
-import "github.com/slackhq/nebula/iputil"
+import (
+	"net/netip"
+)
 
-func (i *HostInfo) GetVpnIp() iputil.VpnIp {
-	return i.vpnIp
+func (i *HostInfo) GetVpnAddrs() []netip.Addr {
+	return i.vpnAddrs
 }
 
 func (i *HostInfo) GetLocalIndex() uint32 {

+ 117 - 44
inside.go

@@ -1,12 +1,14 @@
 package nebula
 
 import (
+	"net/netip"
+
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
-	"github.com/slackhq/nebula/udp"
+	"github.com/slackhq/nebula/routing"
 )
 
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -19,14 +21,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
-		return
+	if f.dropLocalBroadcast {
+		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
+		if found {
+			return
+		}
 	}
 
-	if fwPacket.RemoteIP == f.myVpnIp {
+	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
+	if found {
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
-		// routes packets from the Nebula IP to the Nebula IP through the Nebula
+		// routes packets from the Nebula addr to the Nebula addr through the Nebula
 		// TUN device.
 		if immediatelyForwardToSelf {
 			_, err := f.readers[q].Write(packet)
@@ -35,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 			}
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
-		// routes packets from the nebula IP to the nebula IP through the loopback device.
+		// routes packets from the nebula addr to the nebula addr through the loopback device.
 		return
 	}
 
-	// Ignore broadcast packets
-	if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
+	// Ignore multicast packets
+	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", fwPacket.RemoteIP).
+			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
-				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
+				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
 		return
 	}
@@ -64,7 +70,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
+		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
 
 	} else {
 		f.rejectInside(packet, out, q)
@@ -113,24 +119,97 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 	}
 
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
+	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 
-func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
-	f.getOrHandshake(vpnIp, nil)
+// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
+func (f *Interface) Handshake(vpnAddr netip.Addr) {
+	f.getOrHandshakeNoRouting(vpnAddr, nil)
 }
 
-// getOrHandshake returns nil if the vpnIp is not routable.
+// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
-		vpnIp = f.inside.RouteFor(vpnIp)
-		if vpnIp == 0 {
-			return nil, false
+func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
+	if found {
+		return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
+	}
+
+	return nil, false
+}
+
+// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
+// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
+func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+
+	destinationAddr := fwPacket.RemoteAddr
+
+	hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
+
+	// Host is inside the mesh, no routing required
+	if hostinfo != nil {
+		return hostinfo, ready
+	}
+
+	gateways := f.inside.RoutesFor(destinationAddr)
+
+	switch len(gateways) {
+	case 0:
+		return nil, false
+	case 1:
+		// Single gateway route
+		return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
+	default:
+		// Multi gateway route, perform ECMP categorization
+		gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
+
+		if !balancingOk {
+			// This happens if the gateway buckets were not calculated, this _should_ never happen
+			f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
 		}
+
+		var handshakeInfoForChosenGateway *HandshakeHostInfo
+		var hhReceiver = func(hh *HandshakeHostInfo) {
+			handshakeInfoForChosenGateway = hh
+		}
+
+		// Store the handshakeHostInfo for later.
+		// If this node is not reachable we will attempt other nodes, if none are reachable we will
+		// cache the packet for this gateway.
+		if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
+			return hostinfo, true
+		}
+
+		// It appears the selected gateway cannot be reached, find another gateway to fallback on.
+		// The current implementation breaks ECMP but that seems better than no connectivity.
+		// If ECMP is also required when a gateway is down then connectivity status
+		// for each gateway needs to be kept and the weights recalculated when they go up or down.
+		// This would also need to interact with unsafe_route updates through reloading the config or
+		// use of the use_system_route_table option
+
+		if f.l.Level >= logrus.DebugLevel {
+			f.l.WithField("destination", destinationAddr).
+				WithField("originalGateway", gatewayAddr).
+				Debugln("Calculated gateway for ECMP not available, attempting other gateways")
+		}
+
+		for i := range gateways {
+			// Skip the gateway that failed previously
+			if gateways[i].Addr() == gatewayAddr {
+				continue
+			}
+
+			// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
+			if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
+				return hostinfo, true
+			}
+		}
+
+		// No gateways reachable, cache the packet in the originally chosen gateway
+		cacheCallback(handshakeInfoForChosenGateway)
+		return hostinfo, false
 	}
 
-	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -152,19 +231,19 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 	}
 
-	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 
-// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
+	hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", vpnIp).
-				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
+			f.l.WithField("vpnAddr", vpnAddr).
+				Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
 		}
 		return
 	}
@@ -182,10 +261,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 
-func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
+func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
@@ -255,12 +334,11 @@ func (f *Interface) SendVia(via *HostInfo,
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
-		//TODO: log warning
 		return
 	}
-	useRelay := remote == nil && hostinfo.remote == nil
+	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
 	fullOut := out
 
 	if useRelay {
@@ -284,14 +362,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	f.connectionManager.Out(hostinfo.localIndexId)
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
-	// all our IPs and enable a faster roaming.
+	// all our addrs and enable a faster roaming.
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp)
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 	}
 
@@ -308,13 +386,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		return
 	}
 
-	if remote != nil {
+	if remote.IsValid() {
 		err = f.writers[q].WriteTo(out, remote)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
 				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 		}
-	} else if hostinfo.remote != nil {
+	} else if hostinfo.remote.IsValid() {
 		err = f.writers[q].WriteTo(out, hostinfo.remote)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
@@ -323,7 +401,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
 			if err != nil {
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
@@ -334,8 +412,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		}
 	}
 }
-
-func isMulticast(ip iputil.VpnIp) bool {
-	// Class D multicast
-	return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
-}

+ 84 - 53
interface.go

@@ -5,18 +5,18 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"runtime"
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
@@ -28,7 +28,6 @@ type InterfaceConfig struct {
 	Outside                 udp.Conn
 	Inside                  overlay.Device
 	pki                     *PKI
-	Cipher                  string
 	Firewall                *Firewall
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
@@ -52,25 +51,27 @@ type InterfaceConfig struct {
 }
 
 type Interface struct {
-	hostMap            *HostMap
-	outside            udp.Conn
-	inside             overlay.Device
-	pki                *PKI
-	cipher             string
-	firewall           *Firewall
-	connectionManager  *connectionManager
-	handshakeManager   *HandshakeManager
-	serveDns           bool
-	createTime         time.Time
-	lightHouse         *LightHouse
-	localBroadcast     iputil.VpnIp
-	myVpnIp            iputil.VpnIp
-	dropLocalBroadcast bool
-	dropMulticast      bool
-	routines           int
-	disconnectInvalid  atomic.Bool
-	closed             atomic.Bool
-	relayManager       *relayManager
+	hostMap               *HostMap
+	outside               udp.Conn
+	inside                overlay.Device
+	pki                   *PKI
+	firewall              *Firewall
+	connectionManager     *connectionManager
+	handshakeManager      *HandshakeManager
+	serveDns              bool
+	createTime            time.Time
+	lightHouse            *LightHouse
+	myBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
+	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	dropLocalBroadcast    bool
+	dropMulticast         bool
+	routines              int
+	disconnectInvalid     atomic.Bool
+	closed                atomic.Bool
+	relayManager          *relayManager
 
 	tryPromoteEvery atomic.Uint32
 	reQueryEvery    atomic.Uint32
@@ -102,9 +103,11 @@ type EncWriter interface {
 		out []byte,
 		nocopy bool,
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp iputil.VpnIp)
+	Handshake(vpnAddr netip.Addr)
+	GetHostInfo(vpnAddr netip.Addr) *HostInfo
+	GetCertState() *CertState
 }
 
 type sendRecvErrorConfig uint8
@@ -115,10 +118,10 @@ const (
 	sendRecvErrorPrivate
 )
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
 	switch s {
 	case sendRecvErrorPrivate:
-		return ip.IsPrivate()
+		return endpoint.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 		return true
 	case sendRecvErrorNever:
@@ -155,28 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 	}
 
-	certificate := c.pki.GetCertState().Certificate
-	myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
+	cs := c.pki.getCertState()
 	ifce := &Interface{
-		pki:                c.pki,
-		hostMap:            c.HostMap,
-		outside:            c.Outside,
-		inside:             c.Inside,
-		cipher:             c.Cipher,
-		firewall:           c.Firewall,
-		serveDns:           c.ServeDns,
-		handshakeManager:   c.HandshakeManager,
-		createTime:         time.Now(),
-		lightHouse:         c.lightHouse,
-		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
-		dropLocalBroadcast: c.DropLocalBroadcast,
-		dropMulticast:      c.DropMulticast,
-		routines:           c.routines,
-		version:            c.version,
-		writers:            make([]udp.Conn, c.routines),
-		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnIp:            myVpnIp,
-		relayManager:       c.relayManager,
+		pki:                   c.pki,
+		hostMap:               c.HostMap,
+		outside:               c.Outside,
+		inside:                c.Inside,
+		firewall:              c.Firewall,
+		serveDns:              c.ServeDns,
+		handshakeManager:      c.HandshakeManager,
+		createTime:            time.Now(),
+		lightHouse:            c.lightHouse,
+		dropLocalBroadcast:    c.DropLocalBroadcast,
+		dropMulticast:         c.DropMulticast,
+		routines:              c.routines,
+		version:               c.version,
+		writers:               make([]udp.Conn, c.routines),
+		readers:               make([]io.ReadWriteCloser, c.routines),
+		myVpnNetworks:         cs.myVpnNetworks,
+		myVpnNetworksTable:    cs.myVpnNetworksTable,
+		myVpnAddrs:            cs.myVpnAddrs,
+		myVpnAddrsTable:       cs.myVpnAddrsTable,
+		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
+		relayManager:          c.relayManager,
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
@@ -210,7 +214,7 @@ func (f *Interface) activate() {
 		f.l.WithError(err).Error("Failed to get udp listen address")
 	}
 
-	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
+	f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
@@ -251,16 +255,22 @@ func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 
 	var li udp.Conn
-	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 		li = f.writers[i]
 	} else {
 		li = f.outside
 	}
 
+	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
+	plaintext := make([]byte, udp.MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
+		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	})
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -317,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
@@ -400,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	udpStats := udp.NewUDPStatsEmitter(f.writers)
 
 	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
+	certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
+	certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
 
 	for {
 		select {
@@ -409,11 +421,30 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
-			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
+
+			certState := f.pki.getCertState()
+			defaultCrt := certState.GetDefaultCertificate()
+			certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
+			certDefaultVersion.Update(int64(defaultCrt.Version()))
+
+			// Report the max certificate version we are capable of using
+			if certState.v2Cert != nil {
+				certMaxVersion.Update(int64(certState.v2Cert.Version()))
+			} else {
+				certMaxVersion.Update(int64(certState.v1Cert.Version()))
+			}
 		}
 	}
 }
 
+func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return f.hostMap.QueryVpnAddr(vpnIp)
+}
+
+func (f *Interface) GetCertState() *CertState {
+	return f.pki.getCertState()
+}
+
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 

+ 0 - 93
iputil/util.go

@@ -1,93 +0,0 @@
-package iputil
-
-import (
-	"encoding/binary"
-	"fmt"
-	"net"
-	"net/netip"
-)
-
-type VpnIp uint32
-
-const maxIPv4StringLen = len("255.255.255.255")
-
-func (ip VpnIp) String() string {
-	b := make([]byte, maxIPv4StringLen)
-
-	n := ubtoa(b, 0, byte(ip>>24))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip>>16&255))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip>>8&255))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip&255))
-	return string(b[:n])
-}
-
-func (ip VpnIp) MarshalJSON() ([]byte, error) {
-	return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
-}
-
-func (ip VpnIp) ToIP() net.IP {
-	nip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(nip, uint32(ip))
-	return nip
-}
-
-func (ip VpnIp) ToNetIpAddr() netip.Addr {
-	var nip [4]byte
-	binary.BigEndian.PutUint32(nip[:], uint32(ip))
-	return netip.AddrFrom4(nip)
-}
-
-func Ip2VpnIp(ip []byte) VpnIp {
-	if len(ip) == 16 {
-		return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
-	}
-	return VpnIp(binary.BigEndian.Uint32(ip))
-}
-
-func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
-	addr, ok := netip.AddrFromSlice(ip)
-	if !ok {
-		return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
-	}
-	return addr, nil
-}
-
-func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
-	addr, err := ToNetIpAddr(ipNet.IP)
-	if err != nil {
-		return netip.Prefix{}, err
-	}
-	ones, bits := ipNet.Mask.Size()
-	if ones == 0 && bits == 0 {
-		return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
-	}
-	return netip.PrefixFrom(addr, ones), nil
-}
-
-// ubtoa encodes the string form of the integer v to dst[start:] and
-// returns the number of bytes written to dst. The caller must ensure
-// that dst has sufficient length.
-func ubtoa(dst []byte, start int, v byte) int {
-	if v < 10 {
-		dst[start] = v + '0'
-		return 1
-	} else if v < 100 {
-		dst[start+1] = v%10 + '0'
-		dst[start] = v/10 + '0'
-		return 2
-	}
-
-	dst[start+2] = v%10 + '0'
-	dst[start+1] = (v/10)%10 + '0'
-	dst[start] = v/100 + '0'
-	return 3
-}

+ 0 - 17
iputil/util_test.go

@@ -1,17 +0,0 @@
-package iputil
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestVpnIp_String(t *testing.T) {
-	assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
-	assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
-	assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
-	assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
-	assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
-	assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
-}

File diff suppressed because it is too large
+ 382 - 287
lighthouse.go


+ 256 - 247
lighthouse_test.go

@@ -2,150 +2,151 @@ package nebula
 
 import (
 	"context"
+	"encoding/binary"
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 
+	"github.com/gaissmai/bart"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
-	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
-	"gopkg.in/yaml.v2"
+	"github.com/stretchr/testify/require"
+	"gopkg.in/yaml.v3"
 )
 
-//TODO: Add a test to ensure udpAddr is copied and not reused
-
 func TestOldIPv4Only(t *testing.T) {
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
-	var m Ip4AndPort
+	var m V4AddrPort
 	err := m.Unmarshal(b)
-	assert.NoError(t, err)
-	assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
-}
-
-func TestNewLhQuery(t *testing.T) {
-	myIp := net.ParseIP("192.1.1.1")
-	myIpint := iputil.Ip2VpnIp(myIp)
-
-	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIpint)
-
-	// The result should be a nebulameta protobuf
-	assert.IsType(t, &NebulaMeta{}, a)
-
-	// It should also Marshal fine
-	b, err := a.Marshal()
-	assert.Nil(t, err)
-
-	// and then Unmarshal fine
-	n := &NebulaMeta{}
-	err = n.Unmarshal(b)
-	assert.Nil(t, err)
-
+	require.NoError(t, err)
+	ip := netip.MustParseAddr("10.1.1.1")
+	bp := ip.As4()
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
 }
 
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
-	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
-	assert.Nil(t, err)
+	c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1}}
+	c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
-	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
-	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
+	c.Settings["lighthouse"] = map[string]any{"hosts": []any{lh1, lh2}}
+	c.Settings["static_host_map"] = map[string]any{lh1: []any{"100.1.1.1:4242"}}
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{
-		"hosts":    []interface{}{lh1},
+	c.Settings["lighthouse"] = map[string]any{
+		"hosts":    []any{lh1},
 		"interval": "1s",
 	}
 
-	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
-	assert.NoError(t, err)
+	c.Settings["static_host_map"] = map[string]any{lh1: []any{"1.1.1.1:4242"}}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
 
 	// The first one routine is kicked off by main.go currently, lets make sure that one dies
-	c.ReloadConfigString("lighthouse:\n  interval: 5")
+	require.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 5"))
 	assert.Equal(t, int64(5), lh.interval.Load())
 
 	// Subsequent calls are killed off by the LightHouse.Reload function
-	c.ReloadConfigString("lighthouse:\n  interval: 10")
+	require.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 10"))
 	assert.Equal(t, int64(10), lh.interval.Load())
 
 	// If this completes then nothing is stealing our reload routine
-	c.ReloadConfigString("lighthouse:\n  interval: 11")
+	require.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 11"))
 	assert.Equal(t, int64(11), lh.interval.Load())
 }
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
-
-	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
-	if !assert.NoError(b, err) {
-		b.Fatal()
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
 	}
 
-	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
-	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = NewRemoteList(nil)
-	lh.addrMap[3].unlockedSetV4(
-		3,
-		3,
-		[]*Ip4AndPort{
-			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
-			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+	c := config.NewC(l)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(b, err)
+
+	hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
+	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
+
+	vpnIp3 := netip.MustParseAddr("0.0.0.3")
+	lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
+	lh.addrMap[vpnIp3].unlockedSetV4(
+		vpnIp3,
+		vpnIp3,
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
+			netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
-	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
-	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = NewRemoteList(nil)
-	lh.addrMap[2].unlockedSetV4(
-		3,
-		3,
-		[]*Ip4AndPort{
-			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
-			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
+	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
+	vpnIp2 := netip.MustParseAddr("0.0.0.3")
+	lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
+	lh.addrMap[vpnIp2].unlockedSetV4(
+		vpnIp3,
+		vpnIp3,
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
+			netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	mw := &mockEncWriter{}
 
+	hi := []netip.Addr{vpnIp2}
 	b.Run("notfound", func(b *testing.B) {
 		lhh := lh.NewRequestHandler()
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
-				VpnIp:       4,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  4,
+				V4AddrPorts: nil,
 			},
 		}
 		p, err := req.Marshal()
-		assert.NoError(b, err)
+		require.NoError(b, err)
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, 2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 	})
 	b.Run("found", func(b *testing.B) {
@@ -153,15 +154,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
-				VpnIp:       3,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  3,
+				V4AddrPorts: nil,
 			},
 		}
 		p, err := req.Marshal()
-		assert.NoError(b, err)
+		require.NoError(b, err)
 
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, 2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 	})
 }
@@ -169,71 +170,80 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func TestLighthouse_Memory(t *testing.T) {
 	l := test.NewLogger()
 
-	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
-	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
-	myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
-	myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
-	myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
-	myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
-	myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
-	myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
-	myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
-	myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
-	myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
-	myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
-	myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
-
-	theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
-	theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
-	theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
-	theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
-	theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
-	theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+	myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
+	myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
+	myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
+	myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
+	myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
+	myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
+	myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
+	myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
+	myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
+	myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
+	myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
+	myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
+	myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+	theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
+	theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
+	theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
+	theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
+	theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
+	theirVpnIp := netip.MustParseAddr("10.128.0.3")
 
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
-	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
-	assert.NoError(t, err)
+	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
+	c.Settings["listen"] = map[string]any{"port": 4242}
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	lh.ifce = &mockEncWriter{}
+	require.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
 
 	// Ensure we don't accumulate addresses
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
 
 	// Grow it back to 2
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	// Update a different host and ask about it
-	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Have both hosts ask about the other
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Make sure we didn't get changed
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	// Ensure proper ordering and limiting
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
 	newLHHostUpdate(
 		myUdpAddr0,
 		myVpnIp,
-		[]*udp.Addr{
+		[]netip.AddrPort{
 			myUdpAddr1,
 			myUdpAddr2,
 			myUdpAddr3,
@@ -251,46 +261,60 @@ func TestLighthouse_Memory(t *testing.T) {
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 		t,
-		r.msg.Details.Ip4AndPorts,
+		r.msg.Details.V4AddrPorts,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 
 	// Make sure we won't add ips in our vpn network
-	bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
-	bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
-	good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
+	bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
+	bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
+	good := netip.MustParseAddrPort("1.128.0.99:4242")
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
 }
 
 func TestLighthouse_reload(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
-	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
-	assert.NoError(t, err)
-
-	nc := map[interface{}]interface{}{
-		"static_host_map": map[interface{}]interface{}{
-			"10.128.0.2": []interface{}{"1.1.1.1:4242"},
+	c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true}
+	c.Settings["listen"] = map[string]any{"port": 4242}
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	require.NoError(t, err)
+
+	nc := map[string]any{
+		"static_host_map": map[string]any{
+			"10.128.0.2": []any{"1.1.1.1:4242"},
 		},
 	}
 	rc, err := yaml.Marshal(nc)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 	c.ReloadConfigString(string(rc))
 
 	err = lh.reload(c, false)
-	assert.NoError(t, err)
+	require.NoError(t, err)
 }
 
-func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostQuery,
-		Details: &NebulaMetaDetails{
-			VpnIp: uint32(queryVpnIp),
-		},
+		Type:    NebulaMeta_HostQuery,
+		Details: &NebulaMetaDetails{},
+	}
+
+	if queryVpnIp.Is4() {
+		bip := queryVpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
 	}
 
 	b, err := req.Marshal()
@@ -302,21 +326,29 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
 	w := &testEncWriter{
 		metaFilter: &filter,
 	}
-	lhh.HandleRequest(fromAddr, myVpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
 	return w.lastReply
 }
 
-func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostUpdateNotification,
-		Details: &NebulaMetaDetails{
-			VpnIp:       uint32(vpnIp),
-			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
-		},
+		Type:    NebulaMeta_HostUpdateNotification,
+		Details: &NebulaMetaDetails{},
 	}
 
-	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
+	if vpnIp.Is4() {
+		bip := vpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
+	}
+
+	for _, v := range addrs {
+		if v.Addr().Is4() {
+			req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
+		} else {
+			req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
+		}
 	}
 
 	b, err := req.Marshal()
@@ -325,96 +357,25 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
 	}
 
 	w := &testEncWriter{}
-	lhh.HandleRequest(fromAddr, vpnIp, b, w)
-}
-
-//TODO: this is a RemoteList test
-//func Test_lhRemoteAllowList(t *testing.T) {
-//	l := NewLogger()
-//	c := NewConfig(l)
-//	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
-//		"10.20.0.0/12": false,
-//	}
-//	allowList, err := c.GetAllowList("remoteallowlist", false)
-//	assert.Nil(t, err)
-//
-//	lh1 := "10.128.0.2"
-//	lh1IP := net.ParseIP(lh1)
-//
-//	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-//
-//	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-//	lh.SetRemoteAllowList(allowList)
-//
-//	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
-//	remote1IP := net.ParseIP("10.20.0.3")
-//	remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
-//	remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
-//	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
-//	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
-//
-//	// Make sure a good ip enters the cache and addrMap
-//	remote2IP := net.ParseIP("10.128.0.3")
-//	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
-//
-//	// Another good ip gets into the cache, ordering is inverted
-//	remote3IP := net.ParseIP("10.128.0.4")
-//	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
-//
-//	// If we exceed the length limit we should only have the most recent addresses
-//	addedAddrs := []*udpAddr{}
-//	for i := 0; i < 11; i++ {
-//		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
-//		lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
-//		// The first entry here is a duplicate, don't add it to the assert list
-//		if i != 0 {
-//			addedAddrs = append(addedAddrs, remoteUDPAddr)
-//		}
-//	}
-//
-//	// We should only have the last 10 of what we tried to add
-//	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
-//	assertUdpAddrInArray(
-//		t,
-//		lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
-//		addedAddrs[0],
-//		addedAddrs[1],
-//		addedAddrs[2],
-//		addedAddrs[3],
-//		addedAddrs[4],
-//		addedAddrs[5],
-//		addedAddrs[6],
-//		addedAddrs[7],
-//		addedAddrs[8],
-//		addedAddrs[9],
-//	)
-//}
-
-func Test_ipMaskContains(t *testing.T) {
-	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
-	assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
+	lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
 }
 
 type testLhReply struct {
 	nebType    header.MessageType
 	nebSubType header.MessageSubType
-	vpnIp      iputil.VpnIp
+	vpnIp      netip.Addr
 	msg        *NebulaMeta
 }
 
 type testEncWriter struct {
-	lastReply  testLhReply
-	metaFilter *NebulaMeta_MessageType
+	lastReply       testLhReply
+	metaFilter      *NebulaMeta_MessageType
+	protocolVersion cert.Version
 }
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 }
-func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
+func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
 }
 
 func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
@@ -424,7 +385,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 		tw.lastReply = testLhReply{
 			nebType:    t,
 			nebSubType: st,
-			vpnIp:      hostinfo.vpnIp,
+			vpnIp:      hostinfo.vpnAddrs[0],
 			msg:        msg,
 		}
 	}
@@ -434,7 +395,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 }
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -451,36 +412,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	}
 }
 
-// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
-	if !assert.Len(t, have, len(want)) {
-		return
-	}
+func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return nil
+}
 
-	for k, w := range want {
-		if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
-			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
-		}
-	}
+func (tw *testEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: tw.protocolVersion}
 }
 
-// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
+// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
+func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 		return
 	}
 
 	for k, w := range want {
-		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
-			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+		h := protoV4AddrPortToNetAddrPort(have[k])
+		if !(h == w) {
+			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 	}
 }
 
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
-	addrs := make([]*udp.Addr, len(ips))
-	for k, v := range ips {
-		addrs[k] = NewUDPAddrFromLH4(v)
-	}
-	return addrs
+func Test_findNetworkUnion(t *testing.T) {
+	var out netip.Addr
+	var ok bool
+
+	tenDot := netip.MustParsePrefix("10.0.0.0/8")
+	oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
+	fe80 := netip.MustParsePrefix("fe80::/8")
+	fc00 := netip.MustParsePrefix("fc00::/7")
+
+	a1 := netip.MustParseAddr("10.0.0.1")
+	afe81 := netip.MustParseAddr("fe80::1")
+
+	//simple
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed lengths
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed family
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//ordering
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//some mismatches
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//falsey cases
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
 }

+ 19 - 34
main.go

@@ -2,9 +2,9 @@ package nebula
 
 import (
 	"context"
-	"encoding/binary"
 	"fmt"
 	"net"
+	"net/netip"
 	"time"
 
 	"github.com/sirupsen/logrus"
@@ -13,10 +13,10 @@ import (
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
-type m map[string]interface{}
+type m = map[string]any
 
 func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logger, deviceFactory overlay.DeviceFactory) (retcon *Control, reterr error) {
 	ctx, cancel := context.WithCancel(context.Background())
@@ -60,16 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
 
-	certificate := pki.GetCertState().Certificate
-	fw, err := NewFirewallFromConfig(l, certificate, c)
+	fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
-	// TODO: make sure mask is 4 bytes
-	tunCidr := certificate.Details.Ips[0]
-
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			deviceFactory = overlay.NewDeviceFromConfig
 		}
 
-		tun, err = deviceFactory(c, l, tunCidr, routines)
+		tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
@@ -150,21 +146,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	if !configTest {
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
-		var listenHost *net.IPAddr
+		var listenHost netip.Addr
 		if rawListenHost == "[::]" {
 			// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
-			listenHost = &net.IPAddr{IP: net.IPv6zero}
+			listenHost = netip.IPv6Unspecified()
 
 		} else {
-			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
+			ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
 			if err != nil {
 				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
 			}
+			if len(ips) == 0 {
+				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
+			}
+			listenHost = ips[0].Unmap()
 		}
 
 		for i := 0; i < routines; i++ {
-			l.Infof("listening %q %d", listenHost.IP, port)
-			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
+			l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
+			udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
@@ -178,14 +178,14 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 				if err != nil {
 					return nil, util.NewContextualError("Failed to get listening port", nil, err)
 				}
-				port = int(uPort.Port)
+				port = int(uPort.Port())
 			}
 		}
 	}
 
-	hostMap := NewHostMapFromConfig(l, tunCidr, c)
+	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
@@ -201,7 +201,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	handshakeConfig := HandshakeConfig{
 		tryInterval:   c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
-		retries:       c.GetInt("handshakes.retries", DefaultHandshakeRetries),
+		retries:       int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
 		triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
 		useRelays:     useRelays,
 
@@ -228,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		pki:                     pki,
-		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
@@ -250,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		l:                     l,
 	}
 
-	switch ifConfig.Cipher {
-	case "aes":
-		noiseEndianness = binary.BigEndian
-	case "chachapoly":
-		noiseEndianness = binary.LittleEndian
-	default:
-		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
-	}
-
 	var ifce *Interface
 	if !configTest {
 		ifce, err = NewInterface(ctx, ifConfig)
@@ -266,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			return nil, fmt.Errorf("failed to initialize interface: %s", err)
 		}
 
-		// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
-		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
 		lightHouse.ifce = ifce
 
@@ -279,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		go handshakeManager.Run(ctx)
 	}
 
-	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
-	// a context so that they can exit when the context is Done.
 	statsStart, err := startStats(l, c, buildVersion, configTest)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
@@ -290,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, nil
 	}
 
-	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
 	attachCommands(l, c, ssh, ifce)
@@ -299,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	var dnsStart func()
 	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, c)
+		dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
 	}
 
 	return &Control{

+ 0 - 2
message_metrics.go

@@ -7,8 +7,6 @@ import (
 	"github.com/slackhq/nebula/header"
 )
 
-//TODO: this can probably move into the header package
-
 type MessageMetrics struct {
 	rx [][]metrics.Counter
 	tx [][]metrics.Counter

+ 0 - 18
metadata.go

@@ -1,18 +0,0 @@
-package nebula
-
-/*
-
-import (
-	proto "google.golang.org/protobuf/proto"
-)
-
-func HandleMetaProto(p []byte) {
-	m := &NebulaMeta{}
-	err := proto.Unmarshal(p, m)
-	if err != nil {
-		l.Debugf("problem unmarshaling meta message: %s", err)
-	}
-	//fmt.Println(m)
-}
-
-*/

File diff suppressed because it is too large
+ 467 - 171
nebula.pb.go


+ 23 - 9
nebula.proto

@@ -23,19 +23,28 @@ message NebulaMeta {
 }
 
 message NebulaMetaDetails {
-  uint32 VpnIp = 1;
-  repeated Ip4AndPort Ip4AndPorts = 2;
-  repeated Ip6AndPort Ip6AndPorts = 4;
-  repeated uint32 RelayVpnIp = 5;
+  uint32 OldVpnAddr = 1 [deprecated = true];
+  Addr VpnAddr = 6;
+
+  repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
+  repeated Addr RelayVpnAddrs = 7;
+
+  repeated V4AddrPort V4AddrPorts = 2;
+  repeated V6AddrPort V6AddrPorts = 4;
   uint32 counter = 3;
 }
 
-message Ip4AndPort {
-  uint32 Ip = 1;
+message Addr {
+  uint64 Hi = 1;
+  uint64 Lo = 2;
+}
+
+message V4AddrPort {
+  uint32 Addr = 1;
   uint32 Port = 2;
 }
 
-message Ip6AndPort {
+message V6AddrPort {
   uint64 Hi = 1;
   uint64 Lo = 2;
   uint32 Port = 3;
@@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
   uint32 ResponderIndex = 3;
   uint64 Cookie = 4;
   uint64 Time = 5;
+  uint32 CertVersion = 8;
   // reserved for WIP multiport
   reserved 6, 7;
 }
@@ -76,6 +86,10 @@ message NebulaControl {
 
   uint32 InitiatorRelayIndex = 2;
   uint32 ResponderRelayIndex = 3;
-  uint32 RelayToIp = 4;
-  uint32 RelayFromIp = 5;
+
+  uint32 OldRelayToAddr = 4 [deprecated = true];
+  uint32 OldRelayFromAddr = 5 [deprecated = true];
+
+  Addr RelayToAddr = 6;
+  Addr RelayFromAddr = 7;
 }

Some files were not shown because too many files changed in this diff