Browse Source

Merge remote-tracking branch 'origin/master' into multiport

Wade Simmons 4 months ago
parent
commit
f36db374ac
100 changed files with 9522 additions and 5597 deletions
  1. 1 1
      .github/workflows/gofmt.yml
  2. 3 3
      .github/workflows/release.yml
  3. 3 0
      .github/workflows/smoke-extra.yml
  4. 1 1
      .github/workflows/smoke.yml
  5. 18 18
      .github/workflows/smoke/smoke-vagrant.sh
  6. 22 4
      .github/workflows/test.yml
  7. 3 1
      .gitignore
  8. 14 5
      Makefile
  9. 1 1
      README.md
  10. 25 12
      allow_list.go
  11. 34 25
      calculated_remote.go
  12. 61 5
      calculated_remote_test.go
  13. 1 1
      cert/Makefile
  14. 15 4
      cert/README.md
  15. 52 0
      cert/asn1.go
  16. 0 140
      cert/ca.go
  17. 296 0
      cert/ca_pool.go
  18. 559 0
      cert/ca_pool_test.go
  19. 104 968
      cert/cert.go
  20. 0 1230
      cert/cert_test.go
  21. 489 0
      cert/cert_v1.go
  22. 111 111
      cert/cert_v1.pb.go
  23. 0 0
      cert/cert_v1.proto
  24. 218 0
      cert/cert_v1_test.go
  25. 37 0
      cert/cert_v2.asn1
  26. 730 0
      cert/cert_v2.go
  27. 267 0
      cert/cert_v2_test.go
  28. 159 2
      cert/crypto.go
  29. 87 0
      cert/crypto_test.go
  30. 41 6
      cert/errors.go
  31. 141 0
      cert/helper_test.go
  32. 161 0
      cert/pem.go
  33. 292 0
      cert/pem_test.go
  34. 167 0
      cert/sign.go
  35. 90 0
      cert/sign_test.go
  36. 138 0
      cert_test/cert.go
  37. 137 73
      cmd/nebula-cert/ca.go
  38. 36 30
      cmd/nebula-cert/ca_test.go
  39. 49 20
      cmd/nebula-cert/keygen.go
  40. 6 5
      cmd/nebula-cert/keygen_test.go
  41. 10 3
      cmd/nebula-cert/main_test.go
  42. 15 0
      cmd/nebula-cert/p11_cgo.go
  43. 16 0
      cmd/nebula-cert/p11_stub.go
  44. 14 9
      cmd/nebula-cert/print.go
  45. 144 21
      cmd/nebula-cert/print_test.go
  46. 238 110
      cmd/nebula-cert/sign.go
  47. 74 75
      cmd/nebula-cert/sign_test.go
  48. 26 15
      cmd/nebula-cert/verify.go
  49. 13 30
      cmd/nebula-cert/verify_test.go
  50. 0 3
      config/config_test.go
  51. 61 36
      connection_manager.go
  52. 162 61
      connection_manager_test.go
  53. 32 23
      connection_state.go
  54. 37 36
      control.go
  55. 22 36
      control_test.go
  56. 43 22
      control_tester.go
  57. 79 39
      dns_server.go
  58. 20 5
      dns_server_test.go
  59. 389 168
      e2e/handshakes_test.go
  60. 0 125
      e2e/helpers.go
  61. 75 38
      e2e/helpers_test.go
  62. 5 4
      e2e/router/hostmap.go
  63. 44 22
      e2e/router/router.go
  64. 12 5
      examples/config.yml
  65. 96 94
      firewall.go
  66. 11 10
      firewall/packet.go
  67. 136 195
      firewall_test.go
  68. 21 19
      go.mod
  69. 42 37
      go.sum
  70. 223 104
      handshake_ix.go
  71. 181 124
      handshake_manager.go
  72. 18 11
      handshake_manager_test.go
  73. 171 104
      hostmap.go
  74. 23 33
      hostmap_test.go
  75. 2 2
      hostmap_tester.go
  76. 31 27
      inside.go
  77. 84 77
      interface.go
  78. 0 2
      iputil/packet.go
  79. 364 240
      lighthouse.go
  80. 173 146
      lighthouse_test.go
  81. 5 34
      main.go
  82. 0 2
      message_metrics.go
  83. 560 176
      nebula.pb.go
  84. 23 9
      nebula.proto
  85. 50 0
      noiseutil/pkcs11.go
  86. 156 137
      outside.go
  87. 525 17
      outside_test.go
  88. 1 1
      overlay/device.go
  89. 22 12
      overlay/route.go
  90. 39 33
      overlay/route_test.go
  91. 9 9
      overlay/tun.go
  92. 11 11
      overlay/tun_android.go
  93. 208 215
      overlay/tun_darwin.go
  94. 8 8
      overlay/tun_disabled.go
  95. 25 15
      overlay/tun_freebsd.go
  96. 10 10
      overlay/tun_ios.go
  97. 120 78
      overlay/tun_linux.go
  98. 28 17
      overlay/tun_netbsd.go
  99. 29 19
      overlay/tun_openbsd.go
  100. 17 17
      overlay/tun_tester.go

+ 1 - 1
.github/workflows/gofmt.yml

@@ -18,7 +18,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
 
     - name: Install goimports

+ 3 - 3
.github/workflows/release.yml

@@ -14,7 +14,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version: '1.23'
           check-latest: true
 
       - name: Build
@@ -37,7 +37,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version: '1.23'
           check-latest: true
 
       - name: Build
@@ -70,7 +70,7 @@ jobs:
 
       - uses: actions/setup-go@v5
         with:
-          go-version: '1.22'
+          go-version: '1.23'
           check-latest: true
 
       - name: Import certificates

+ 3 - 0
.github/workflows/smoke-extra.yml

@@ -27,6 +27,9 @@ jobs:
         go-version-file: 'go.mod'
         check-latest: true
 
+    - name: add hashicorp source
+      run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list
+
     - name: install vagrant
       run: sudo apt-get update && sudo apt-get install -y vagrant virtualbox
 

+ 1 - 1
.github/workflows/smoke.yml

@@ -22,7 +22,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
 
     - name: build

+ 18 - 18
.github/workflows/smoke/smoke-vagrant.sh

@@ -29,13 +29,13 @@ docker run --name lighthouse1 --rm "$CONTAINER" -config lighthouse1.yml -test
 docker run --name host2 --rm "$CONTAINER" -config host2.yml -test
 
 vagrant up
-vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test"
+vagrant ssh -c "cd /nebula && /nebula/$1-nebula -config host3.yml -test" -- -T
 
 docker run --name lighthouse1 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config lighthouse1.yml 2>&1 | tee logs/lighthouse1 | sed -u 's/^/  [lighthouse1]  /' &
 sleep 1
 docker run --name host2 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN --rm "$CONTAINER" -config host2.yml 2>&1 | tee logs/host2 | sed -u 's/^/  [host2]  /' &
 sleep 1
-vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" &
+vagrant ssh -c "cd /nebula && sudo sh -c 'echo \$\$ >/nebula/pid && exec /nebula/$1-nebula -config host3.yml'" 2>&1 -- -T | tee logs/host3 | sed -u 's/^/  [host3]  /' &
 sleep 15
 
 # grab tcpdump pcaps for debugging
@@ -46,8 +46,8 @@ docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host
 # vagrant ssh -c "tcpdump -i nebula1 -q -w - -U" 2>logs/host3.inside.log >logs/host3.inside.pcap &
 # vagrant ssh -c "tcpdump -i eth0 -q -w - -U" 2>logs/host3.outside.log >logs/host3.outside.pcap &
 
-docker exec host2 ncat -nklv 0.0.0.0 2000 &
-vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
+#docker exec host2 ncat -nklv 0.0.0.0 2000 &
+#vagrant ssh -c "ncat -nklv 0.0.0.0 2000" &
 #docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
 #vagrant ssh -c "ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000" &
 
@@ -68,11 +68,11 @@ docker exec host2 ping -c1 192.168.100.1
 # Should fail because not allowed by host3 inbound firewall
 ! docker exec host2 ping -c1 192.168.100.3 -w5 || exit 1
 
-set +x
-echo
-echo " *** Testing ncat from host2"
-echo
-set -x
+#set +x
+#echo
+#echo " *** Testing ncat from host2"
+#echo
+#set -x
 # Should fail because not allowed by host3 inbound firewall
 #! docker exec host2 ncat -nzv -w5 192.168.100.3 2000 || exit 1
 #! docker exec host2 ncat -nzuv -w5 192.168.100.3 3000 | grep -q host3 || exit 1
@@ -82,18 +82,18 @@ echo
 echo " *** Testing ping from host3"
 echo
 set -x
-vagrant ssh -c "ping -c1 192.168.100.1"
-vagrant ssh -c "ping -c1 192.168.100.2"
-
-set +x
-echo
-echo " *** Testing ncat from host3"
-echo
-set -x
+vagrant ssh -c "ping -c1 192.168.100.1" -- -T
+vagrant ssh -c "ping -c1 192.168.100.2" -- -T
+
+#set +x
+#echo
+#echo " *** Testing ncat from host3"
+#echo
+#set -x
 #vagrant ssh -c "ncat -nzv -w5 192.168.100.2 2000"
 #vagrant ssh -c "ncat -nzuv -w5 192.168.100.2 3000" | grep -q host2
 
-vagrant ssh -c "sudo xargs kill </nebula/pid"
+vagrant ssh -c "sudo xargs kill </nebula/pid" -- -T
 docker exec host2 sh -c 'kill 1'
 docker exec lighthouse1 sh -c 'kill 1'
 sleep 1

+ 22 - 4
.github/workflows/test.yml

@@ -22,7 +22,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
 
     - name: Build
@@ -55,7 +55,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
 
     - name: Build
@@ -65,7 +65,25 @@ jobs:
       run: make test-boringcrypto
 
     - name: End 2 end
-      run: make e2evv GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+      run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0"
+
+  test-linux-pkcs11:
+    name: Build and test on linux with pkcs11
+    runs-on: ubuntu-latest
+    steps:
+
+    - uses: actions/checkout@v4
+
+    - uses: actions/setup-go@v5
+      with:
+        go-version: '1.22'
+        check-latest: true
+
+    - name: Build
+      run: make bin-pkcs11
+
+    - name: Test
+      run: make test-pkcs11
 
   test:
     name: Build and test on ${{ matrix.os }}
@@ -79,7 +97,7 @@ jobs:
 
     - uses: actions/setup-go@v5
       with:
-        go-version: '1.22'
+        go-version: '1.23'
         check-latest: true
 
     - name: Build nebula

+ 3 - 1
.gitignore

@@ -5,7 +5,8 @@
 /nebula-darwin
 /nebula.exe
 /nebula-cert.exe
-/coverage.out
+**/coverage.out
+**/cover.out
 /cpu.pprof
 /build
 /*.tar.gz
@@ -13,5 +14,6 @@
 **.crt
 **.key
 **.pem
+**.pub
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.key
 !/examples/quickstart-vagrant/ansible/roles/nebula/files/vagrant-test-ca.crt

+ 14 - 5
Makefile

@@ -40,7 +40,7 @@ ALL_LINUX = linux-amd64 \
 	linux-mips64le \
 	linux-mips-softfloat \
 	linux-riscv64 \
-        linux-loong64
+	linux-loong64
 
 ALL_FREEBSD = freebsd-amd64 \
 	freebsd-arm64
@@ -63,7 +63,7 @@ ALL = $(ALL_LINUX) \
 e2e:
 	$(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e
 
-e2ev: TEST_FLAGS = -v
+e2ev: TEST_FLAGS += -v
 e2ev: e2e
 
 e2evv: TEST_ENV += TEST_LOGS=1
@@ -96,7 +96,7 @@ release-netbsd: $(ALL_NETBSD:%=build/nebula-%.tar.gz)
 
 release-boringcrypto: build/nebula-linux-$(shell go env GOARCH)-boringcrypto.tar.gz
 
-BUILD_ARGS = -trimpath
+BUILD_ARGS += -trimpath
 
 bin-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe
 	mv $? .
@@ -116,6 +116,10 @@ bin-freebsd-arm64: build/freebsd-arm64/nebula build/freebsd-arm64/nebula-cert
 bin-boringcrypto: build/linux-$(shell go env GOARCH)-boringcrypto/nebula build/linux-$(shell go env GOARCH)-boringcrypto/nebula-cert
 	mv $? .
 
+bin-pkcs11: BUILD_ARGS += -tags pkcs11
+bin-pkcs11: CGO_ENABLED = 1
+bin-pkcs11: bin
+
 bin:
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula${NEBULA_CMD_SUFFIX} ${NEBULA_CMD_PATH}
 	go build $(BUILD_ARGS) -ldflags "$(LDFLAGS)" -o ./nebula-cert${NEBULA_CMD_SUFFIX} ./cmd/nebula-cert
@@ -133,6 +137,8 @@ build/linux-mips-softfloat/%: LDFLAGS += -s -w
 # boringcrypto
 build/linux-amd64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
 build/linux-arm64-boringcrypto/%: GOENV += GOEXPERIMENT=boringcrypto CGO_ENABLED=1
+build/linux-amd64-boringcrypto/%: LDFLAGS += -checklinkname=0
+build/linux-arm64-boringcrypto/%: LDFLAGS += -checklinkname=0
 
 build/%/nebula: .FORCE
 	GOOS=$(firstword $(subst -, , $*)) \
@@ -166,7 +172,10 @@ test:
 	go test -v ./...
 
 test-boringcrypto:
-	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -v ./...
+	GOEXPERIMENT=boringcrypto CGO_ENABLED=1 go test -ldflags "-checklinkname=0" -v ./...
+
+test-pkcs11:
+	CGO_ENABLED=1 go test -v -tags pkcs11 ./...
 
 test-cov-html:
 	go test -coverprofile=coverage.out
@@ -189,7 +198,7 @@ bench-cpu-long:
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 
-proto: nebula.pb.go cert/cert.pb.go
+proto: nebula.pb.go cert/cert_v1.pb.go
 
 nebula.pb.go: nebula.proto .FORCE
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster

+ 1 - 1
README.md

@@ -12,7 +12,7 @@ Further documentation can be found [here](https://nebula.defined.net/docs/).
 
 You can read more about Nebula [here](https://medium.com/p/884110a5579).
 
-You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/enQtOTA5MDI4NDg3MTg4LTkwY2EwNTI4NzQyMzc0M2ZlODBjNWI3NTY1MzhiOThiMmZlZjVkMTI0NGY4YTMyNjUwMWEyNzNkZTJmYzQxOGU).
+You can also join the NebulaOSS Slack group [here](https://join.slack.com/t/nebulaoss/shared_invite/zt-2xqe6e7vn-k_KGi8s13nsr7cvHVvHvuQ).
 
 ## Supported Platforms
 

+ 25 - 12
allow_list.go

@@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
 
-		// TODO: should we error on duplicate CIDRs in the config?
 		tree.Insert(ipNet, value)
 
 		maskBits := ipNet.Bits()
@@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
 	return remoteAllowRanges, nil
 }
 
-func (al *AllowList) Allow(ip netip.Addr) bool {
+func (al *AllowList) Allow(addr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
 
-	result, _ := al.cidrTree.Lookup(ip)
+	result, _ := al.cidrTree.Lookup(addr)
 	return result
 }
 
-func (al *LocalAllowList) Allow(ip netip.Addr) bool {
+func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
 func (al *LocalAllowList) AllowName(name string) bool {
@@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 }
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
+func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(vpnAddr)
 }
 
-func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
-	if !al.getInsideAllowList(vpnIp).Allow(ip) {
+func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
+	if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
 		return false
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
-func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
+func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
+	if !al.AllowList.Allow(udpAddr) {
+		return false
+	}
+
+	for _, vpnAddr := range vpnAddrs {
+		if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
+			return false
+		}
+	}
+
+	return true
+}
+
+func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
-		inside, ok := al.insideAllowLists.Lookup(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnAddr)
 		if ok {
 			return inside
 		}

+ 34 - 25
calculated_remote.go

@@ -21,7 +21,11 @@ type calculatedRemote struct {
 	port  uint32
 }
 
-func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
+		return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
+	}
+
 	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
@@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 
-func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
-	// Combine the masked bytes of the "mask" IP with the unmasked bytes
-	// of the overlay IP
-	if c.ipNet.Addr().Is4() {
-		return c.apply4(ip)
-	}
-	return c.apply6(ip)
-}
-
-func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK this can be less crappy
+func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
+	// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
 	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
 	mask := binary.BigEndian.Uint32(maskb[:])
 
 	b := c.mask.Addr().As4()
-	maskIp := binary.BigEndian.Uint32(b[:])
+	maskAddr := binary.BigEndian.Uint32(b[:])
 
-	b = ip.As4()
-	intIp := binary.BigEndian.Uint32(b[:])
+	b = addr.As4()
+	intAddr := binary.BigEndian.Uint32(b[:])
 
-	return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+	return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
 }
 
-func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK
-	panic("Can not calculate ipv6 remote addresses")
+func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
+	mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	maskAddr := c.mask.Addr().As16()
+	calcAddr := addr.As16()
+
+	ap := V6AddrPort{Port: c.port}
+
+	maskb := binary.BigEndian.Uint64(mask[:8])
+	maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
+	calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
+	ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	maskb = binary.BigEndian.Uint64(mask[8:])
+	maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
+	calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
+	ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	return &ap
 }
 
 func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 
-		//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
-		entry, err := newCalculatedRemotesListFromConfig(rawValue)
+		entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
@@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 	return calculatedRemotes, nil
 }
 
-func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
+func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
 	rawList, ok := raw.([]any)
 	if !ok {
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 
 	var l []*calculatedRemote
 	for _, e := range rawList {
-		c, err := newCalculatedRemotesEntryFromConfig(e)
+		c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
 		if err != nil {
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 		}
@@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 	return l, nil
 }
 
-func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
+func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
 	rawMap, ok := raw.(map[any]any)
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
@@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 
-	return newCalculatedRemote(maskCidr, port)
+	return newCalculatedRemote(cidr, maskCidr, port)
 }

+ 61 - 5
calculated_remote_test.go

@@ -9,10 +9,9 @@ import (
 )
 
 func TestCalculatedRemoteApply(t *testing.T) {
-	ipNet, err := netip.ParsePrefix("192.168.1.0/24")
-	require.NoError(t, err)
-
-	c, err := newCalculatedRemote(ipNet, 4242)
+	// Test v4 addresses
+	ipNet := netip.MustParsePrefix("192.168.1.0/24")
+	c, err := newCalculatedRemote(ipNet, ipNet, 4242)
 	require.NoError(t, err)
 
 	input, err := netip.ParseAddr("10.0.10.182")
@@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	expected, err := netip.ParseAddr("192.168.1.182")
 	assert.NoError(t, err)
 
-	assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
+	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
+
+	// Test v6 addresses
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+}
+
+func Test_newCalculatedRemote(t *testing.T) {
+	c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
 }

+ 1 - 1
cert/Makefile

@@ -1,7 +1,7 @@
 GO111MODULE = on
 export GO111MODULE
 
-cert.pb.go: cert.proto .FORCE
+cert_v1.pb.go: cert_v1.proto .FORCE
 	go build google.golang.org/protobuf/cmd/protoc-gen-go
 	PATH="$(CURDIR):$(PATH)" protoc --go_out=. --go_opt=paths=source_relative $<
 	rm protoc-gen-go

+ 15 - 4
cert/README.md

@@ -2,14 +2,25 @@
 
 This is a library for interacting with `nebula` style certificates and authorities.
 
-A `protobuf` definition of the certificate format is also included
+There are now 2 versions of `nebula` certificates:
 
-### Compiling the protobuf definition
+## v1
 
-Make sure you have `protoc` installed.
+This version is deprecated.
+
+A `protobuf` definition of the certificate format is included at `cert_v1.proto`
+
+To compile the definition you will need `protoc` installed.
 
 To compile for `go` with the same version of protobuf specified in go.mod:
 
 ```bash
-make
+make proto
 ```
+
+## v2
+
+This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
+future certificate changes better than v1.
+
+`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

+ 52 - 0
cert/asn1.go

@@ -0,0 +1,52 @@
+package cert
+
+import (
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+)
+
+// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
+// https://github.com/golang/go/issues/64811#issuecomment-1944446920
+func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0] > 0
+		return true
+	}
+
+	return false
+}
+
+// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
+// Similar issue as with readOptionalASN1Boolean
+func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0]
+		return true
+	}
+
+	return false
+}

+ 0 - 140
cert/ca.go

@@ -1,140 +0,0 @@
-package cert
-
-import (
-	"errors"
-	"fmt"
-	"strings"
-	"time"
-)
-
-type NebulaCAPool struct {
-	CAs           map[string]*NebulaCertificate
-	certBlocklist map[string]struct{}
-}
-
-// NewCAPool creates a CAPool
-func NewCAPool() *NebulaCAPool {
-	ca := NebulaCAPool{
-		CAs:           make(map[string]*NebulaCertificate),
-		certBlocklist: make(map[string]struct{}),
-	}
-
-	return &ca
-}
-
-// NewCAPoolFromBytes will create a new CA pool from the provided
-// input bytes, which must be a PEM-encoded set of nebula certificates.
-// If the pool contains any expired certificates, an ErrExpired will be
-// returned along with the pool. The caller must handle any such errors.
-func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
-	pool := NewCAPool()
-	var err error
-	var expired bool
-	for {
-		caPEMs, err = pool.AddCACertificate(caPEMs)
-		if errors.Is(err, ErrExpired) {
-			expired = true
-			err = nil
-		}
-		if err != nil {
-			return nil, err
-		}
-		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
-			break
-		}
-	}
-
-	if expired {
-		return pool, ErrExpired
-	}
-
-	return pool, nil
-}
-
-// AddCACertificate verifies a Nebula CA certificate and adds it to the pool
-// Only the first pem encoded object will be consumed, any remaining bytes are returned.
-// Parsed certificates will be verified and must be a CA
-func (ncp *NebulaCAPool) AddCACertificate(pemBytes []byte) ([]byte, error) {
-	c, pemBytes, err := UnmarshalNebulaCertificateFromPEM(pemBytes)
-	if err != nil {
-		return pemBytes, err
-	}
-
-	if !c.Details.IsCA {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA)
-	}
-
-	if !c.CheckSignature(c.Details.PublicKey) {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned)
-	}
-
-	sum, err := c.Sha256Sum()
-	if err != nil {
-		return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name)
-	}
-
-	ncp.CAs[sum] = c
-	if c.Expired(time.Now()) {
-		return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired)
-	}
-
-	return pemBytes, nil
-}
-
-// BlocklistFingerprint adds a cert fingerprint to the blocklist
-func (ncp *NebulaCAPool) BlocklistFingerprint(f string) {
-	ncp.certBlocklist[f] = struct{}{}
-}
-
-// ResetCertBlocklist removes all previously blocklisted cert fingerprints
-func (ncp *NebulaCAPool) ResetCertBlocklist() {
-	ncp.certBlocklist = make(map[string]struct{})
-}
-
-// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated
-// automatically if you manually change any fields in the NebulaCertificate.
-func (ncp *NebulaCAPool) IsBlocklisted(c *NebulaCertificate) bool {
-	return ncp.isBlocklistedWithCache(c, false)
-}
-
-// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted
-func (ncp *NebulaCAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool {
-	h, err := c.sha256SumWithCache(useCache)
-	if err != nil {
-		return true
-	}
-
-	if _, ok := ncp.certBlocklist[h]; ok {
-		return true
-	}
-
-	return false
-}
-
-// GetCAForCert attempts to return the signing certificate for the provided certificate.
-// No signature validation is performed
-func (ncp *NebulaCAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) {
-	if c.Details.Issuer == "" {
-		return nil, fmt.Errorf("no issuer in certificate")
-	}
-
-	signer, ok := ncp.CAs[c.Details.Issuer]
-	if ok {
-		return signer, nil
-	}
-
-	return nil, fmt.Errorf("could not find ca for the certificate")
-}
-
-// GetFingerprints returns an array of trusted CA fingerprints
-func (ncp *NebulaCAPool) GetFingerprints() []string {
-	fp := make([]string, len(ncp.CAs))
-
-	i := 0
-	for k := range ncp.CAs {
-		fp[i] = k
-		i++
-	}
-
-	return fp
-}

+ 296 - 0
cert/ca_pool.go

@@ -0,0 +1,296 @@
+package cert
+
+import (
+	"errors"
+	"fmt"
+	"net/netip"
+	"slices"
+	"strings"
+	"time"
+)
+
+type CAPool struct {
+	CAs           map[string]*CachedCertificate
+	certBlocklist map[string]struct{}
+}
+
+// NewCAPool creates an empty CAPool
+func NewCAPool() *CAPool {
+	ca := CAPool{
+		CAs:           make(map[string]*CachedCertificate),
+		certBlocklist: make(map[string]struct{}),
+	}
+
+	return &ca
+}
+
+// NewCAPoolFromPEM will create a new CA pool from the provided
+// input bytes, which must be a PEM-encoded set of nebula certificates.
+// If the pool contains any expired certificates, an ErrExpired will be
+// returned along with the pool. The caller must handle any such errors.
+func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
+	pool := NewCAPool()
+	var err error
+	var expired bool
+	for {
+		caPEMs, err = pool.AddCAFromPEM(caPEMs)
+		if errors.Is(err, ErrExpired) {
+			expired = true
+			err = nil
+		}
+		if err != nil {
+			return nil, err
+		}
+		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
+			break
+		}
+	}
+
+	if expired {
+		return pool, ErrExpired
+	}
+
+	return pool, nil
+}
+
+// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool.
+// Only the first pem encoded object will be consumed, any remaining bytes are returned.
+// Parsed certificates will be verified and must be a CA
+func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) {
+	c, pemBytes, err := UnmarshalCertificateFromPEM(pemBytes)
+	if err != nil {
+		return pemBytes, err
+	}
+
+	err = ncp.AddCA(c)
+	if err != nil {
+		return pemBytes, err
+	}
+
+	return pemBytes, nil
+}
+
+// AddCA verifies a Nebula CA certificate and adds it to the pool.
+func (ncp *CAPool) AddCA(c Certificate) error {
+	if !c.IsCA() {
+		return fmt.Errorf("%s: %w", c.Name(), ErrNotCA)
+	}
+
+	if !c.CheckSignature(c.PublicKey()) {
+		return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned)
+	}
+
+	sum, err := c.Fingerprint()
+	if err != nil {
+		return fmt.Errorf("could not calculate fingerprint for provided CA; error: %w; %s", err, c.Name())
+	}
+
+	cc := &CachedCertificate{
+		Certificate:    c,
+		Fingerprint:    sum,
+		InvertedGroups: make(map[string]struct{}),
+	}
+
+	for _, g := range c.Groups() {
+		cc.InvertedGroups[g] = struct{}{}
+	}
+
+	ncp.CAs[sum] = cc
+
+	if c.Expired(time.Now()) {
+		return fmt.Errorf("%s: %w", c.Name(), ErrExpired)
+	}
+
+	return nil
+}
+
+// BlocklistFingerprint adds a cert fingerprint to the blocklist
+func (ncp *CAPool) BlocklistFingerprint(f string) {
+	ncp.certBlocklist[f] = struct{}{}
+}
+
+// ResetCertBlocklist removes all previously blocklisted cert fingerprints
+func (ncp *CAPool) ResetCertBlocklist() {
+	ncp.certBlocklist = make(map[string]struct{})
+}
+
+// IsBlocklisted tests the provided fingerprint against the pools blocklist.
+// Returns true if the fingerprint is blocked.
+func (ncp *CAPool) IsBlocklisted(fingerprint string) bool {
+	if _, ok := ncp.certBlocklist[fingerprint]; ok {
+		return true
+	}
+
+	return false
+}
+
+// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool.
+// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts
+// to increase performance.
+func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) {
+	if c == nil {
+		return nil, fmt.Errorf("no certificate")
+	}
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, fmt.Errorf("could not calculate fingerprint to verify: %w", err)
+	}
+
+	signer, err := ncp.verify(c, now, fp, "")
+	if err != nil {
+		return nil, err
+	}
+
+	cc := CachedCertificate{
+		Certificate:       c,
+		InvertedGroups:    make(map[string]struct{}),
+		Fingerprint:       fp,
+		signerFingerprint: signer.Fingerprint,
+	}
+
+	for _, g := range c.Groups() {
+		cc.InvertedGroups[g] = struct{}{}
+	}
+
+	return &cc, nil
+}
+
+// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and
+// is a cheaper operation to perform as a result.
+func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error {
+	_, err := ncp.verify(c.Certificate, now, c.Fingerprint, c.signerFingerprint)
+	return err
+}
+
+func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp string) (*CachedCertificate, error) {
+	if ncp.IsBlocklisted(certFp) {
+		return nil, ErrBlockListed
+	}
+
+	signer, err := ncp.GetCAForCert(c)
+	if err != nil {
+		return nil, err
+	}
+
+	if signer.Certificate.Expired(now) {
+		return nil, ErrRootExpired
+	}
+
+	if c.Expired(now) {
+		return nil, ErrExpired
+	}
+
+	// If we are checking a cached certificate then we can bail early here
+	// Either the root is no longer trusted or everything is fine
+	if len(signerFp) > 0 {
+		if signerFp != signer.Fingerprint {
+			return nil, ErrFingerprintMismatch
+		}
+		return signer, nil
+	}
+	if !c.CheckSignature(signer.Certificate.PublicKey()) {
+		return nil, ErrSignatureMismatch
+	}
+
+	err = CheckCAConstraints(signer.Certificate, c)
+	if err != nil {
+		return nil, err
+	}
+
+	return signer, nil
+}
+
+// GetCAForCert attempts to return the signing certificate for the provided certificate.
+// No signature validation is performed
+func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
+	issuer := c.Issuer()
+	if issuer == "" {
+		return nil, fmt.Errorf("no issuer in certificate")
+	}
+
+	signer, ok := ncp.CAs[issuer]
+	if ok {
+		return signer, nil
+	}
+
+	return nil, ErrCaNotFound
+}
+
+// GetFingerprints returns an array of trusted CA fingerprints
+func (ncp *CAPool) GetFingerprints() []string {
+	fp := make([]string, len(ncp.CAs))
+
+	i := 0
+	for k := range ncp.CAs {
+		fp[i] = k
+		i++
+	}
+
+	return fp
+}
+
+// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate.
+func CheckCAConstraints(signer Certificate, sub Certificate) error {
+	return checkCAConstraints(signer, sub.NotBefore(), sub.NotAfter(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks())
+}
+
+// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested.
+func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error {
+	// Make sure this cert isn't valid after the root
+	if notAfter.After(signer.NotAfter()) {
+		return fmt.Errorf("certificate expires after signing certificate")
+	}
+
+	// Make sure this cert wasn't valid before the root
+	if notBefore.Before(signer.NotBefore()) {
+		return fmt.Errorf("certificate is valid before the signing certificate")
+	}
+
+	// If the signer has a limited set of groups make sure the cert only contains a subset
+	signerGroups := signer.Groups()
+	if len(signerGroups) > 0 {
+		for _, g := range groups {
+			if !slices.Contains(signerGroups, g) {
+				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
+			}
+		}
+	}
+
+	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
+	signingNetworks := signer.Networks()
+	if len(signingNetworks) > 0 {
+		for _, certNetwork := range networks {
+			found := false
+			for _, signingNetwork := range signingNetworks {
+				if signingNetwork.Contains(certNetwork.Addr()) && signingNetwork.Bits() <= certNetwork.Bits() {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				return fmt.Errorf("certificate contained a network assignment outside the limitations of the signing ca: %s", certNetwork.String())
+			}
+		}
+	}
+
+	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
+	signingUnsafeNetworks := signer.UnsafeNetworks()
+	if len(signingUnsafeNetworks) > 0 {
+		for _, certUnsafeNetwork := range unsafeNetworks {
+			found := false
+			for _, caNetwork := range signingUnsafeNetworks {
+				if caNetwork.Contains(certUnsafeNetwork.Addr()) && caNetwork.Bits() <= certUnsafeNetwork.Bits() {
+					found = true
+					break
+				}
+			}
+
+			if !found {
+				return fmt.Errorf("certificate contained an unsafe network assignment outside the limitations of the signing ca: %s", certUnsafeNetwork.String())
+			}
+		}
+	}
+
+	return nil
+}

+ 559 - 0
cert/ca_pool_test.go

@@ -0,0 +1,559 @@
+package cert
+
+import (
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestNewCAPoolFromBytes(t *testing.T) {
+	noNewLines := `
+# Current provisional, Remove once everything moves over to the real root.
+-----BEGIN NEBULA CERTIFICATE-----
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
+-----END NEBULA CERTIFICATE-----
+# root-ca01
+-----BEGIN NEBULA CERTIFICATE-----
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
+-----END NEBULA CERTIFICATE-----
+`
+
+	withNewLines := `
+# Current provisional, Remove once everything moves over to the real root.
+
+-----BEGIN NEBULA CERTIFICATE-----
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
+-----END NEBULA CERTIFICATE-----
+
+# root-ca01
+
+
+-----BEGIN NEBULA CERTIFICATE-----
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
+-----END NEBULA CERTIFICATE-----
+
+`
+
+	expired := `
+# expired certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
+7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
+Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
+-----END NEBULA CERTIFICATE-----
+`
+
+	p256 := `
+# p256 certificate
+-----BEGIN NEBULA CERTIFICATE-----
+CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
+k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
+75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
+-----END NEBULA CERTIFICATE-----
+`
+
+	rootCA := certificateV1{
+		details: detailsV1{
+			name: "nebula root ca",
+		},
+	}
+
+	rootCA01 := certificateV1{
+		details: detailsV1{
+			name: "nebula root ca 01",
+		},
+	}
+
+	rootCAP256 := certificateV1{
+		details: detailsV1{
+			name: "nebula P256 test",
+		},
+	}
+
+	p, err := NewCAPoolFromPEM([]byte(noNewLines))
+	assert.Nil(t, err)
+	assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+
+	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
+	assert.Nil(t, err)
+	assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+
+	// expired cert, no valid certs
+	ppp, err := NewCAPoolFromPEM([]byte(expired))
+	assert.Equal(t, ErrExpired, err)
+	assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
+
+	// expired cert, with valid certs
+	pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
+	assert.Equal(t, ErrExpired, err)
+	assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+	assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
+	assert.Equal(t, len(pppp.CAs), 3)
+
+	ppppp, err := NewCAPoolFromPEM([]byte(p256))
+	assert.Nil(t, err)
+	assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
+	assert.Equal(t, len(ppppp.CAs), 1)
+}
+
+func TestCertificateV1_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}

+ 104 - 968
cert/cert.go

@@ -1,1029 +1,165 @@
 package cert
 
 import (
-	"bytes"
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/ed25519"
-	"crypto/elliptic"
-	"crypto/rand"
-	"crypto/sha256"
-	"encoding/binary"
-	"encoding/hex"
-	"encoding/json"
-	"encoding/pem"
-	"errors"
 	"fmt"
-	"math"
-	"math/big"
-	"net"
-	"sync/atomic"
+	"net/netip"
 	"time"
-
-	"golang.org/x/crypto/curve25519"
-	"google.golang.org/protobuf/proto"
 )
 
-const publicKeyLen = 32
+type Version uint8
 
 const (
-	CertBanner                       = "NEBULA CERTIFICATE"
-	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
-	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
-	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
-	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
-	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
-
-	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
-	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
-	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
-	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+	VersionPre1 Version = 0
+	Version1    Version = 1
+	Version2    Version = 2
 )
 
-type NebulaCertificate struct {
-	Details   NebulaCertificateDetails
-	Signature []byte
+type Certificate interface {
+	// Version defines the underlying certificate structure and wire protocol version
+	// Version1 certificates are ipv4 only and uses protobuf serialization
+	// Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization
+	Version() Version
 
-	// the cached hex string of the calculated sha256sum
-	// for VerifyWithCache
-	sha256sum atomic.Pointer[string]
+	// Name is the human-readable name that identifies this certificate.
+	Name() string
 
-	// the cached public key bytes if they were verified as the signer
-	// for VerifyWithCache
-	signatureVerified atomic.Pointer[[]byte]
-}
+	// Networks is a list of ip addresses and network sizes assigned to this certificate.
+	// If IsCA is true then certificates signed by this CA can only have ip addresses and
+	// networks that are contained by an entry in this list.
+	Networks() []netip.Prefix
 
-type NebulaCertificateDetails struct {
-	Name      string
-	Ips       []*net.IPNet
-	Subnets   []*net.IPNet
-	Groups    []string
-	NotBefore time.Time
-	NotAfter  time.Time
-	PublicKey []byte
-	IsCA      bool
-	Issuer    string
+	// UnsafeNetworks is a list of networks that this host can act as an unsafe router for.
+	// If IsCA is true then certificates signed by this CA can only have networks that are
+	// contained by an entry in this list.
+	UnsafeNetworks() []netip.Prefix
 
-	// Map of groups for faster lookup
-	InvertedGroups map[string]struct{}
+	// Groups is a list of identities that can be used to write more general firewall rule
+	// definitions.
+	// If IsCA is true then certificates signed by this CA can only use groups that are
+	// in this list.
+	Groups() []string
 
-	Curve Curve
-}
+	// IsCA signifies if this is a certificate authority (true) or a host certificate (false).
+	// It is invalid to use a CA certificate as a host certificate.
+	IsCA() bool
 
-type NebulaEncryptedData struct {
-	EncryptionMetadata NebulaEncryptionMetadata
-	Ciphertext         []byte
-}
+	// NotBefore is the time at which this certificate becomes valid.
+	// If IsCA is true then certificate signed by this CA can not have a time before this.
+	NotBefore() time.Time
 
-type NebulaEncryptionMetadata struct {
-	EncryptionAlgorithm string
-	Argon2Parameters    Argon2Parameters
-}
+	// NotAfter is the time at which this certificate becomes invalid.
+	// If IsCA is true then certificate signed by this CA can not have a time after this.
+	NotAfter() time.Time
 
-type m map[string]interface{}
+	// Issuer is the fingerprint of the CA that signed this certificate.
+	// If IsCA is true then this will be empty.
+	Issuer() string
 
-// Returned if we try to unmarshal an encrypted private key without a passphrase
-var ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
+	// PublicKey is the raw bytes to be used in asymmetric cryptographic operations.
+	PublicKey() []byte
 
-// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert
-func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) {
-	if len(b) == 0 {
-		return nil, fmt.Errorf("nil byte array")
-	}
-	var rc RawNebulaCertificate
-	err := proto.Unmarshal(b, &rc)
-	if err != nil {
-		return nil, err
-	}
-
-	if rc.Details == nil {
-		return nil, fmt.Errorf("encoded Details was nil")
-	}
-
-	if len(rc.Details.Ips)%2 != 0 {
-		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
-	}
-
-	if len(rc.Details.Subnets)%2 != 0 {
-		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
-	}
+	// Curve identifies which curve was used for the PublicKey and Signature.
+	Curve() Curve
 
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           rc.Details.Name,
-			Groups:         make([]string, len(rc.Details.Groups)),
-			Ips:            make([]*net.IPNet, len(rc.Details.Ips)/2),
-			Subnets:        make([]*net.IPNet, len(rc.Details.Subnets)/2),
-			NotBefore:      time.Unix(rc.Details.NotBefore, 0),
-			NotAfter:       time.Unix(rc.Details.NotAfter, 0),
-			PublicKey:      make([]byte, len(rc.Details.PublicKey)),
-			IsCA:           rc.Details.IsCA,
-			InvertedGroups: make(map[string]struct{}),
-			Curve:          rc.Details.Curve,
-		},
-		Signature: make([]byte, len(rc.Signature)),
-	}
+	// Signature is the cryptographic seal for all the details of this certificate.
+	// CheckSignature can be used to verify that the details of this certificate are valid.
+	Signature() []byte
 
-	copy(nc.Signature, rc.Signature)
-	copy(nc.Details.Groups, rc.Details.Groups)
-	nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer)
+	// CheckSignature will check that the certificate Signature() matches the
+	// computed signature. A true result means this certificate has not been tampered with.
+	CheckSignature(signingPublicKey []byte) bool
 
-	if len(rc.Details.PublicKey) < publicKeyLen {
-		return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
-	}
-	copy(nc.Details.PublicKey, rc.Details.PublicKey)
+	// Fingerprint returns the hex encoded sha256 sum of the certificate.
+	// This acts as a unique fingerprint and can be used to blocklist certificates.
+	Fingerprint() (string, error)
 
-	for i, rawIp := range rc.Details.Ips {
-		if i%2 == 0 {
-			nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)}
-		} else {
-			nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp))
-		}
-	}
+	// Expired tests if the certificate is valid for the provided time.
+	Expired(t time.Time) bool
 
-	for i, rawIp := range rc.Details.Subnets {
-		if i%2 == 0 {
-			nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)}
-		} else {
-			nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp))
-		}
-	}
+	// VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key.
+	VerifyPrivateKey(curve Curve, privateKey []byte) error
 
-	for _, g := range rc.Details.Groups {
-		nc.Details.InvertedGroups[g] = struct{}{}
-	}
+	// Marshal will return the byte representation of this certificate
+	// This is primarily the format transmitted on the wire.
+	Marshal() ([]byte, error)
 
-	return &nc, nil
-}
+	// MarshalForHandshakes prepares the bytes needed to use directly in a handshake
+	MarshalForHandshakes() ([]byte, error)
 
-// UnmarshalNebulaCertificateFromPEM will unmarshal the first pem block in a byte array, returning any non consumed data
-// or an error on failure
-func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) {
-	p, r := pem.Decode(b)
-	if p == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if p.Type != CertBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
-	}
-	nc, err := UnmarshalNebulaCertificate(p.Bytes)
-	return nc, r, err
-}
+	// MarshalPEM will return a PEM encoded representation of this certificate
+	// This is primarily the format stored on disk
+	MarshalPEM() ([]byte, error)
 
-func MarshalPrivateKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
+	// MarshalJSON will return the json representation of this certificate
+	MarshalJSON() ([]byte, error)
 
-func MarshalSigningPrivateKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
+	// String will return a human-readable representation of this certificate
+	String() string
 
-// MarshalX25519PrivateKey is a simple helper to PEM encode an X25519 private key
-func MarshalX25519PrivateKey(b []byte) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
+	// Copy creates a copy of the certificate
+	Copy() Certificate
 }
 
-// MarshalEd25519PrivateKey is a simple helper to PEM encode an Ed25519 private key
-func MarshalEd25519PrivateKey(key ed25519.PrivateKey) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: key})
+// CachedCertificate represents a verified certificate with some cached fields to improve
+// performance.
+type CachedCertificate struct {
+	Certificate       Certificate
+	InvertedGroups    map[string]struct{}
+	Fingerprint       string
+	signerFingerprint string
 }
 
-func UnmarshalPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var expectedLen int
-	var curve Curve
-	switch k.Type {
-	case X25519PrivateKeyBanner:
-		expectedLen = 32
-		curve = Curve_CURVE25519
-	case P256PrivateKeyBanner:
-		expectedLen = 32
-		curve = Curve_P256
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula private key banner")
-	}
-	if len(k.Bytes) != expectedLen {
-		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
-	}
-	return k.Bytes, r, curve, nil
-}
-
-func UnmarshalSigningPrivateKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var curve Curve
-	switch k.Type {
-	case EncryptedEd25519PrivateKeyBanner:
-		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
-	case EncryptedECDSAP256PrivateKeyBanner:
-		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
-	case Ed25519PrivateKeyBanner:
-		curve = Curve_CURVE25519
-		if len(k.Bytes) != ed25519.PrivateKeySize {
-			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
-		}
-	case ECDSAP256PrivateKeyBanner:
-		curve = Curve_P256
-		if len(k.Bytes) != 32 {
-			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
-		}
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
-	}
-	return k.Bytes, r, curve, nil
-}
-
-// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
-func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
-	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
-	if err != nil {
-		return nil, err
-	}
-
-	b, err = proto.Marshal(&RawNebulaEncryptedData{
-		EncryptionMetadata: &RawNebulaEncryptionMetadata{
-			EncryptionAlgorithm: "AES-256-GCM",
-			Argon2Parameters: &RawNebulaArgon2Parameters{
-				Version:     kdfParams.version,
-				Memory:      kdfParams.Memory,
-				Parallelism: uint32(kdfParams.Parallelism),
-				Iterations:  kdfParams.Iterations,
-				Salt:        kdfParams.salt,
-			},
-		},
-		Ciphertext: ciphertext,
-	})
-	if err != nil {
-		return nil, err
-	}
-
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
-	default:
-		return nil, fmt.Errorf("invalid curve: %v", curve)
-	}
-}
-
-// UnmarshalX25519PrivateKey will try to pem decode an X25519 private key, returning any other bytes b
-// or an error on failure
-func UnmarshalX25519PrivateKey(b []byte) ([]byte, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != X25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 private key banner")
-	}
-	if len(k.Bytes) != publicKeyLen {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 private key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalEd25519PrivateKey will try to pem decode an Ed25519 private key, returning any other bytes b
-// or an error on failure
-func UnmarshalEd25519PrivateKey(b []byte) (ed25519.PrivateKey, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-
-	if k.Type == EncryptedEd25519PrivateKeyBanner {
-		return nil, r, ErrPrivateKeyEncrypted
-	} else if k.Type != Ed25519PrivateKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 private key banner")
-	}
-
-	if len(k.Bytes) != ed25519.PrivateKeySize {
-		return nil, r, fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
-// protobuf-generated struct.
-func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
-	if len(b) == 0 {
-		return nil, fmt.Errorf("nil byte array")
-	}
-	var rned RawNebulaEncryptedData
-	err := proto.Unmarshal(b, &rned)
-	if err != nil {
-		return nil, err
-	}
-
-	if rned.EncryptionMetadata == nil {
-		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
-	}
-
-	if rned.EncryptionMetadata.Argon2Parameters == nil {
-		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
-	}
-
-	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
-	if err != nil {
-		return nil, err
-	}
-
-	ned := NebulaEncryptedData{
-		EncryptionMetadata: NebulaEncryptionMetadata{
-			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
-			Argon2Parameters:    *params,
-		},
-		Ciphertext: rned.Ciphertext,
-	}
-
-	return &ned, nil
-}
-
-func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
-	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
-		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
-	}
-	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
-		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
-	}
-	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
-		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
-	}
-	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
-		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
-	}
-
-	return &Argon2Parameters{
-		version:     rune(params.Version),
-		Memory:      uint32(params.Memory),
-		Parallelism: uint8(params.Parallelism),
-		Iterations:  uint32(params.Iterations),
-		salt:        params.Salt,
-	}, nil
-
-}
-
-// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
-// the given passphrase, returning any other bytes b or an error on failure
-func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
-	var curve Curve
-
-	k, r := pem.Decode(b)
-	if k == nil {
-		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-
-	switch k.Type {
-	case EncryptedEd25519PrivateKeyBanner:
-		curve = Curve_CURVE25519
-	case EncryptedECDSAP256PrivateKeyBanner:
-		curve = Curve_P256
-	default:
-		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
-	}
-
-	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
-	if err != nil {
-		return curve, nil, r, err
-	}
-
-	var bytes []byte
-	switch ned.EncryptionMetadata.EncryptionAlgorithm {
-	case "AES-256-GCM":
-		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
-		if err != nil {
-			return curve, nil, r, err
-		}
-	default:
-		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
-	}
-
-	switch curve {
-	case Curve_CURVE25519:
-		if len(bytes) != ed25519.PrivateKeySize {
-			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
-		}
-	case Curve_P256:
-		if len(bytes) != 32 {
-			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
-		}
-	}
-
-	return curve, bytes, r, nil
-}
-
-func MarshalPublicKey(curve Curve, b []byte) []byte {
-	switch curve {
-	case Curve_CURVE25519:
-		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
-	case Curve_P256:
-		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
-	default:
-		return nil
-	}
-}
-
-// MarshalX25519PublicKey is a simple helper to PEM encode an X25519 public key
-func MarshalX25519PublicKey(b []byte) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
-}
-
-// MarshalEd25519PublicKey is a simple helper to PEM encode an Ed25519 public key
-func MarshalEd25519PublicKey(key ed25519.PublicKey) []byte {
-	return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: key})
-}
-
-func UnmarshalPublicKey(b []byte) ([]byte, []byte, Curve, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	var expectedLen int
-	var curve Curve
-	switch k.Type {
-	case X25519PublicKeyBanner:
-		expectedLen = 32
-		curve = Curve_CURVE25519
-	case P256PublicKeyBanner:
-		// Uncompressed
-		expectedLen = 65
-		curve = Curve_P256
-	default:
-		return nil, r, 0, fmt.Errorf("bytes did not contain a proper nebula public key banner")
-	}
-	if len(k.Bytes) != expectedLen {
-		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
-	}
-	return k.Bytes, r, curve, nil
-}
-
-// UnmarshalX25519PublicKey will try to pem decode an X25519 public key, returning any other bytes b
-// or an error on failure
-func UnmarshalX25519PublicKey(b []byte) ([]byte, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != X25519PublicKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula X25519 public key banner")
-	}
-	if len(k.Bytes) != publicKeyLen {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid X25519 public key")
-	}
-
-	return k.Bytes, r, nil
-}
-
-// UnmarshalEd25519PublicKey will try to pem decode an Ed25519 public key, returning any other bytes b
-// or an error on failure
-func UnmarshalEd25519PublicKey(b []byte) (ed25519.PublicKey, []byte, error) {
-	k, r := pem.Decode(b)
-	if k == nil {
-		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
-	}
-	if k.Type != Ed25519PublicKeyBanner {
-		return nil, r, fmt.Errorf("bytes did not contain a proper nebula Ed25519 public key banner")
-	}
-	if len(k.Bytes) != ed25519.PublicKeySize {
-		return nil, r, fmt.Errorf("key was not 32 bytes, is invalid ed25519 public key")
-	}
-
-	return k.Bytes, r, nil
+func (cc *CachedCertificate) String() string {
+	return cc.Certificate.String()
 }
 
-// Sign signs a nebula cert with the provided private key
-func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error {
-	if curve != nc.Details.Curve {
-		return fmt.Errorf("curve in cert and private key supplied don't match")
+// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
+// Handshakes save space by placing the peers public key in a different part of the packet, we have to
+// reassemble the actual certificate structure with that in mind.
+func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
+	if publicKey == nil {
+		return nil, ErrNoPeerStaticKey
 	}
 
-	b, err := proto.Marshal(nc.getRawDetails())
-	if err != nil {
-		return err
+	if rawCertBytes == nil {
+		return nil, ErrNoPayload
 	}
 
-	var sig []byte
-
-	switch curve {
-	case Curve_CURVE25519:
-		signer := ed25519.PrivateKey(key)
-		sig = ed25519.Sign(signer, b)
-	case Curve_P256:
-		signer := &ecdsa.PrivateKey{
-			PublicKey: ecdsa.PublicKey{
-				Curve: elliptic.P256(),
-			},
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-			D: new(big.Int).SetBytes(key),
-		}
-		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-		signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
-
-		// We need to hash first for ECDSA
-		// - https://pkg.go.dev/crypto/ecdsa#SignASN1
-		hashed := sha256.Sum256(b)
-		sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
-		if err != nil {
-			return err
-		}
-	default:
-		return fmt.Errorf("invalid curve: %s", nc.Details.Curve)
-	}
-
-	nc.Signature = sig
-	return nil
-}
-
-// CheckSignature verifies the signature against the provided public key
-func (nc *NebulaCertificate) CheckSignature(key []byte) bool {
-	b, err := proto.Marshal(nc.getRawDetails())
+	c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
 	if err != nil {
-		return false
-	}
-	switch nc.Details.Curve {
-	case Curve_CURVE25519:
-		return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature)
-	case Curve_P256:
-		x, y := elliptic.Unmarshal(elliptic.P256(), key)
-		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
-		hashed := sha256.Sum256(b)
-		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature)
-	default:
-		return false
-	}
-}
-
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool {
-	if !useCache {
-		return nc.CheckSignature(key)
-	}
-
-	if v := nc.signatureVerified.Load(); v != nil {
-		return bytes.Equal(*v, key)
-	}
-
-	verified := nc.CheckSignature(key)
-	if verified {
-		keyCopy := make([]byte, len(key))
-		copy(keyCopy, key)
-		nc.signatureVerified.Store(&keyCopy)
+		return nil, fmt.Errorf("error unmarshaling cert: %w", err)
 	}
 
-	return verified
-}
-
-// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false
-func (nc *NebulaCertificate) Expired(t time.Time) bool {
-	return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t)
-}
-
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
-	return nc.verify(t, ncp, false)
-}
-
-// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-//
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *NebulaCAPool) (bool, error) {
-	return nc.verify(t, ncp, true)
-}
-
-// ResetCache resets the cache used by VerifyWithCache.
-func (nc *NebulaCertificate) ResetCache() {
-	nc.sha256sum.Store(nil)
-	nc.signatureVerified.Store(nil)
-}
-
-// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
-func (nc *NebulaCertificate) verify(t time.Time, ncp *NebulaCAPool, useCache bool) (bool, error) {
-	if ncp.isBlocklistedWithCache(nc, useCache) {
-		return false, ErrBlockListed
-	}
-
-	signer, err := ncp.GetCAForCert(nc)
+	cc, err := caPool.VerifyCertificate(time.Now(), c)
 	if err != nil {
-		return false, err
-	}
-
-	if signer.Expired(t) {
-		return false, ErrRootExpired
-	}
-
-	if nc.Expired(t) {
-		return false, ErrExpired
+		return nil, fmt.Errorf("certificate validation failed: %w", err)
 	}
 
-	if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) {
-		return false, ErrSignatureMismatch
-	}
-
-	if err := nc.CheckRootConstrains(signer); err != nil {
-		return false, err
-	}
-
-	return true, nil
+	return cc, nil
 }
 
-// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets)
-func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error {
-	// Make sure this cert wasn't valid before the root
-	if signer.Details.NotAfter.Before(nc.Details.NotAfter) {
-		return fmt.Errorf("certificate expires after signing certificate")
-	}
+func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
+	var c Certificate
+	var err error
 
-	// Make sure this cert isn't valid after the root
-	if signer.Details.NotBefore.After(nc.Details.NotBefore) {
-		return fmt.Errorf("certificate is valid before the signing certificate")
-	}
-
-	// If the signer has a limited set of groups make sure the cert only contains a subset
-	if len(signer.Details.InvertedGroups) > 0 {
-		for _, g := range nc.Details.Groups {
-			if _, ok := signer.Details.InvertedGroups[g]; !ok {
-				return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g)
-			}
-		}
-	}
-
-	// If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset
-	if len(signer.Details.Ips) > 0 {
-		for _, ip := range nc.Details.Ips {
-			if !netMatch(ip, signer.Details.Ips) {
-				return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String())
-			}
-		}
-	}
-
-	// If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset
-	if len(signer.Details.Subnets) > 0 {
-		for _, subnet := range nc.Details.Subnets {
-			if !netMatch(subnet, signer.Details.Subnets) {
-				return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet)
-			}
-		}
-	}
-
-	return nil
-}
-
-// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match
-func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error {
-	if curve != nc.Details.Curve {
-		return fmt.Errorf("curve in cert and private key supplied don't match")
-	}
-	if nc.Details.IsCA {
-		switch curve {
-		case Curve_CURVE25519:
-			// the call to PublicKey below will panic slice bounds out of range otherwise
-			if len(key) != ed25519.PrivateKeySize {
-				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
-			}
-
-			if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
-				return fmt.Errorf("public key in cert and private key supplied don't match")
-			}
-		case Curve_P256:
-			privkey, err := ecdh.P256().NewPrivateKey(key)
-			if err != nil {
-				return fmt.Errorf("cannot parse private key as P256")
-			}
-			pub := privkey.PublicKey().Bytes()
-			if !bytes.Equal(pub, nc.Details.PublicKey) {
-				return fmt.Errorf("public key in cert and private key supplied don't match")
-			}
-		default:
-			return fmt.Errorf("invalid curve: %s", curve)
-		}
-		return nil
-	}
-
-	var pub []byte
-	switch curve {
-	case Curve_CURVE25519:
-		var err error
-		pub, err = curve25519.X25519(key, curve25519.Basepoint)
-		if err != nil {
-			return err
-		}
-	case Curve_P256:
-		privkey, err := ecdh.P256().NewPrivateKey(key)
-		if err != nil {
-			return err
-		}
-		pub = privkey.PublicKey().Bytes()
+	switch v {
+	// Implementations must ensure the result is a valid cert!
+	case VersionPre1, Version1:
+		c, err = unmarshalCertificateV1(b, publicKey)
+	case Version2:
+		c, err = unmarshalCertificateV2(b, publicKey, curve)
 	default:
-		return fmt.Errorf("invalid curve: %s", curve)
-	}
-	if !bytes.Equal(pub, nc.Details.PublicKey) {
-		return fmt.Errorf("public key in cert and private key supplied don't match")
-	}
-
-	return nil
-}
-
-// String will return a pretty printed representation of a nebula cert
-func (nc *NebulaCertificate) String() string {
-	if nc == nil {
-		return "NebulaCertificate {}\n"
-	}
-
-	s := "NebulaCertificate {\n"
-	s += "\tDetails {\n"
-	s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name)
-
-	if len(nc.Details.Ips) > 0 {
-		s += "\t\tIps: [\n"
-		for _, ip := range nc.Details.Ips {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tIps: []\n"
-	}
-
-	if len(nc.Details.Subnets) > 0 {
-		s += "\t\tSubnets: [\n"
-		for _, ip := range nc.Details.Subnets {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tSubnets: []\n"
-	}
-
-	if len(nc.Details.Groups) > 0 {
-		s += "\t\tGroups: [\n"
-		for _, g := range nc.Details.Groups {
-			s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tGroups: []\n"
-	}
-
-	s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore)
-	s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter)
-	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA)
-	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer)
-	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey)
-	s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve)
-	s += "\t}\n"
-	fp, err := nc.Sha256Sum()
-	if err == nil {
-		s += fmt.Sprintf("\tFingerprint: %s\n", fp)
+		//TODO: CERT-V2 make a static var
+		return nil, fmt.Errorf("unknown certificate version %d", v)
 	}
-	s += fmt.Sprintf("\tSignature: %x\n", nc.Signature)
-	s += "}"
 
-	return s
-}
-
-// getRawDetails marshals the raw details into protobuf ready struct
-func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails {
-	rd := &RawNebulaCertificateDetails{
-		Name:      nc.Details.Name,
-		Groups:    nc.Details.Groups,
-		NotBefore: nc.Details.NotBefore.Unix(),
-		NotAfter:  nc.Details.NotAfter.Unix(),
-		PublicKey: make([]byte, len(nc.Details.PublicKey)),
-		IsCA:      nc.Details.IsCA,
-		Curve:     nc.Details.Curve,
-	}
-
-	for _, ipNet := range nc.Details.Ips {
-		rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask))
-	}
-
-	for _, ipNet := range nc.Details.Subnets {
-		rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask))
-	}
-
-	copy(rd.PublicKey, nc.Details.PublicKey[:])
-
-	// I know, this is terrible
-	rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer)
-
-	return rd
-}
-
-// Marshal will marshal a nebula cert into a protobuf byte array
-func (nc *NebulaCertificate) Marshal() ([]byte, error) {
-	rc := RawNebulaCertificate{
-		Details:   nc.getRawDetails(),
-		Signature: nc.Signature,
-	}
-
-	return proto.Marshal(&rc)
-}
-
-// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result
-func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) {
-	b, err := nc.Marshal()
 	if err != nil {
 		return nil, err
 	}
-	return pem.EncodeToMemory(&pem.Block{Type: CertBanner, Bytes: b}), nil
-}
-
-// Sha256Sum calculates a sha-256 sum of the marshaled certificate
-func (nc *NebulaCertificate) Sha256Sum() (string, error) {
-	b, err := nc.Marshal()
-	if err != nil {
-		return "", err
-	}
-
-	sum := sha256.Sum256(b)
-	return hex.EncodeToString(sum[:]), nil
-}
-
-// NOTE: This uses an internal cache that will not be invalidated automatically
-// if you manually change any fields in the NebulaCertificate.
-func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) {
-	if !useCache {
-		return nc.Sha256Sum()
-	}
-
-	if s := nc.sha256sum.Load(); s != nil {
-		return *s, nil
-	}
-	s, err := nc.Sha256Sum()
-	if err != nil {
-		return s, err
-	}
-
-	nc.sha256sum.Store(&s)
-	return s, nil
-}
-
-func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) {
-	toString := func(ips []*net.IPNet) []string {
-		s := []string{}
-		for _, ip := range ips {
-			s = append(s, ip.String())
-		}
-		return s
-	}
-
-	fp, _ := nc.Sha256Sum()
-	jc := m{
-		"details": m{
-			"name":      nc.Details.Name,
-			"ips":       toString(nc.Details.Ips),
-			"subnets":   toString(nc.Details.Subnets),
-			"groups":    nc.Details.Groups,
-			"notBefore": nc.Details.NotBefore,
-			"notAfter":  nc.Details.NotAfter,
-			"publicKey": fmt.Sprintf("%x", nc.Details.PublicKey),
-			"isCa":      nc.Details.IsCA,
-			"issuer":    nc.Details.Issuer,
-			"curve":     nc.Details.Curve.String(),
-		},
-		"fingerprint": fp,
-		"signature":   fmt.Sprintf("%x", nc.Signature),
-	}
-	return json.Marshal(jc)
-}
-
-//func (nc *NebulaCertificate) Copy() *NebulaCertificate {
-//	r, err := nc.Marshal()
-//	if err != nil {
-//		//TODO
-//		return nil
-//	}
-//
-//	c, err := UnmarshalNebulaCertificate(r)
-//	return c
-//}
-
-func (nc *NebulaCertificate) Copy() *NebulaCertificate {
-	c := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           nc.Details.Name,
-			Groups:         make([]string, len(nc.Details.Groups)),
-			Ips:            make([]*net.IPNet, len(nc.Details.Ips)),
-			Subnets:        make([]*net.IPNet, len(nc.Details.Subnets)),
-			NotBefore:      nc.Details.NotBefore,
-			NotAfter:       nc.Details.NotAfter,
-			PublicKey:      make([]byte, len(nc.Details.PublicKey)),
-			IsCA:           nc.Details.IsCA,
-			Issuer:         nc.Details.Issuer,
-			InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)),
-		},
-		Signature: make([]byte, len(nc.Signature)),
-	}
-
-	copy(c.Signature, nc.Signature)
-	copy(c.Details.Groups, nc.Details.Groups)
-	copy(c.Details.PublicKey, nc.Details.PublicKey)
-
-	for i, p := range nc.Details.Ips {
-		c.Details.Ips[i] = &net.IPNet{
-			IP:   make(net.IP, len(p.IP)),
-			Mask: make(net.IPMask, len(p.Mask)),
-		}
-		copy(c.Details.Ips[i].IP, p.IP)
-		copy(c.Details.Ips[i].Mask, p.Mask)
-	}
-
-	for i, p := range nc.Details.Subnets {
-		c.Details.Subnets[i] = &net.IPNet{
-			IP:   make(net.IP, len(p.IP)),
-			Mask: make(net.IPMask, len(p.Mask)),
-		}
-		copy(c.Details.Subnets[i].IP, p.IP)
-		copy(c.Details.Subnets[i].Mask, p.Mask)
-	}
-
-	for g := range nc.Details.InvertedGroups {
-		c.Details.InvertedGroups[g] = struct{}{}
-	}
-
-	return c
-}
-
-func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool {
-	for _, net := range rootIps {
-		if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) {
-			return true
-		}
-	}
-
-	return false
-}
 
-func maskContains(caMask, certMask net.IPMask) bool {
-	caM := maskTo4(caMask)
-	cM := maskTo4(certMask)
-	// Make sure forcing to ipv4 didn't nuke us
-	if caM == nil || cM == nil {
-		return false
+	if c.Curve() != curve {
+		return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
 	}
 
-	// Make sure the cert mask is not greater than the ca mask
-	for i := 0; i < len(caMask); i++ {
-		if caM[i] > cM[i] {
-			return false
-		}
-	}
-
-	return true
-}
-
-func maskTo4(ip net.IPMask) net.IPMask {
-	if len(ip) == net.IPv4len {
-		return ip
-	}
-
-	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
-		return ip[12:16]
-	}
-
-	return nil
-}
-
-func isZeros(b []byte) bool {
-	for i := 0; i < len(b); i++ {
-		if b[i] != 0 {
-			return false
-		}
-	}
-	return true
-}
-
-func ip2int(ip []byte) uint32 {
-	if len(ip) == 16 {
-		return binary.BigEndian.Uint32(ip[12:16])
-	}
-	return binary.BigEndian.Uint32(ip)
-}
-
-func int2ip(nn uint32) net.IP {
-	ip := make(net.IP, net.IPv4len)
-	binary.BigEndian.PutUint32(ip, nn)
-	return ip
+	return c, nil
 }

+ 0 - 1230
cert/cert_test.go

@@ -1,1230 +0,0 @@
-package cert
-
-import (
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/elliptic"
-	"crypto/rand"
-	"fmt"
-	"io"
-	"net"
-	"testing"
-	"time"
-
-	"github.com/slackhq/nebula/test"
-	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-	"google.golang.org/protobuf/proto"
-)
-
-func TestMarshalingNebulaCertificate(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-
-	nc2, err := UnmarshalNebulaCertificate(b)
-	assert.Nil(t, err)
-
-	assert.Equal(t, nc.Signature, nc2.Signature)
-	assert.Equal(t, nc.Details.Name, nc2.Details.Name)
-	assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore)
-	assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter)
-	assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey)
-	assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA)
-
-	// IP byte arrays can be 4 or 16 in length so we have to go this route
-	assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips))
-	for i, wIp := range nc.Details.Ips {
-		assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String())
-	}
-
-	assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets))
-	for i, wIp := range nc.Details.Subnets {
-		assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String())
-	}
-
-	assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups)
-}
-
-func TestNebulaCertificate_Sign(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-	}
-
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	assert.Nil(t, err)
-	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(Curve_CURVE25519, priv))
-	assert.True(t, nc.CheckSignature(pub))
-
-	_, err = nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-}
-
-func TestNebulaCertificate_SignP256(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Curve:     Curve_P256,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-	}
-
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	assert.Nil(t, err)
-	assert.False(t, nc.CheckSignature(pub))
-	assert.Nil(t, nc.Sign(Curve_P256, rawPriv))
-	assert.True(t, nc.CheckSignature(pub))
-
-	_, err = nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-}
-
-func TestNebulaCertificate_Expired(t *testing.T) {
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
-			NotAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
-		},
-	}
-
-	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
-	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
-	assert.False(t, nc.Expired(time.Now()))
-}
-
-func TestNebulaCertificate_MarshalJSON(t *testing.T) {
-	time.Local = time.UTC
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
-			NotAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.MarshalJSON()
-	assert.Nil(t, err)
-	assert.Equal(
-		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
-		string(b),
-	)
-}
-
-func TestNebulaCertificate_Verify(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	h, err := ca.Sha256Sum()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.CAs[h] = ca
-
-	f, err := c.Sha256Sum()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is expired")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	h, err := ca.Sha256Sum()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.CAs[h] = ca
-
-	f, err := c.Sha256Sum()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now().Add(time.Minute*6), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate is expired")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_IPs(t *testing.T) {
-	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
-	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	// ip is outside the network
-	cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}}
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_Subnets(t *testing.T) {
-	_, caIp1, _ := net.ParseCIDR("10.0.0.0/16")
-	_, caIp2, _ := net.ParseCIDR("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalToPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	caPool.AddCACertificate(caPem)
-
-	// ip is outside the network
-	cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}}
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err := c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.False(t, v)
-	assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}}
-	cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}}
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	v, err = c.Verify(time.Now(), caPool)
-	assert.True(t, v)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := x25519Keypair()
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{})
-	err = c.VerifyPrivateKey(Curve_P256, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := p256Keypair()
-	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNewCAPoolFromBytes(t *testing.T) {
-	noNewLines := `
-# Current provisional, Remove once everything moves over to the real root.
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-# root-ca01
------BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
------END NEBULA CERTIFICATE-----
-`
-
-	withNewLines := `
-# Current provisional, Remove once everything moves over to the real root.
-
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-
-# root-ca01
-
-
------BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
------END NEBULA CERTIFICATE-----
-
-`
-
-	expired := `
-# expired certificate
------BEGIN NEBULA CERTIFICATE-----
-CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4
-vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie
-WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
------END NEBULA CERTIFICATE-----
-`
-
-	p256 := `
-# p256 certificate
------BEGIN NEBULA CERTIFICATE-----
-CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
-6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H
-76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
-IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
------END NEBULA CERTIFICATE-----
-`
-
-	rootCA := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula root ca",
-		},
-	}
-
-	rootCA01 := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula root ca 01",
-		},
-	}
-
-	rootCAP256 := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "nebula P256 test",
-		},
-	}
-
-	p, err := NewCAPoolFromBytes([]byte(noNewLines))
-	assert.Nil(t, err)
-	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-
-	pp, err := NewCAPoolFromBytes([]byte(withNewLines))
-	assert.Nil(t, err)
-	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-
-	// expired cert, no valid certs
-	ppp, err := NewCAPoolFromBytes([]byte(expired))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
-
-	// expired cert, with valid certs
-	pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
-	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
-	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
-	assert.Equal(t, len(pppp.CAs), 3)
-
-	ppppp, err := NewCAPoolFromBytes([]byte(p256))
-	assert.Nil(t, err)
-	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
-	assert.Equal(t, len(ppppp.CAs), 1)
-}
-
-func appendByteSlices(b ...[]byte) []byte {
-	retSlice := []byte{}
-	for _, v := range b {
-		retSlice = append(retSlice, v...)
-	}
-	return retSlice
-}
-
-func TestUnmrshalCertPEM(t *testing.T) {
-	goodCert := []byte(`
-# A good cert
------BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NEBULA CERTIFICATE-----
-`)
-	badBanner := []byte(`# A bad banner
------BEGIN NOT A NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
------END NOT A NEBULA CERTIFICATE-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
--END NEBULA CERTIFICATE----`)
-
-	certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
-
-	// Success test case
-	cert, rest, err := UnmarshalNebulaCertificateFromPEM(certBundle)
-	assert.NotNil(t, cert)
-	assert.Equal(t, rest, append(badBanner, invalidPem...))
-	assert.Nil(t, err)
-
-	// Fail due to invalid banner.
-	cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
-	assert.Nil(t, cert)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula certificate banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	cert, rest, err = UnmarshalNebulaCertificateFromPEM(rest)
-	assert.Nil(t, cert)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalSigningPrivateKey(t *testing.T) {
-	privKey := []byte(`# A good key
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	privP256Key := []byte(`# A good key
------BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ECDSA P256 PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NOT A NEBULA PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
--END NEBULA ED25519 PRIVATE KEY-----`)
-
-	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalSigningPrivateKey(keyBundle)
-	assert.Len(t, k, 64)
-	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-	assert.Nil(t, err)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519/ECDSA private key banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalSigningPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
-	passphrase := []byte("DO NOT USE THIS KEY")
-	privKey := []byte(`# A good key
------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
------END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A key which, once decrypted, is too short
------BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
-k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
-GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
-rQr3bdH3Oy/WiYU=
------END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner (not encrypted)
------BEGIN NEBULA ED25519 PRIVATE KEY-----
-bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
-XgLvodMXZJuaFPssp+WwtA==
------END NEBULA ED25519 PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
-oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
-+Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
-qrlJ69wer3ZUHFXA
--END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
-`)
-
-	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
-	assert.Nil(t, err)
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Len(t, k, 64)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-
-	// Fail due to short key
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-
-	// Fail due to invalid banner
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to invalid passphrase
-	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
-	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
-	assert.Nil(t, k)
-	assert.Equal(t, rest, []byte{})
-}
-
-func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
-	// Having proved that decryption works correctly above, we can test the
-	// encryption function produces a value which can be decrypted
-	passphrase := []byte("passphrase")
-	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
-	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
-	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
-	assert.Nil(t, err)
-
-	// Verify the "key" can be decrypted successfully
-	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
-	assert.Len(t, k, 64)
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Equal(t, rest, []byte{})
-	assert.Nil(t, err)
-
-	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
-}
-
-func TestUnmarshalPrivateKey(t *testing.T) {
-	privKey := []byte(`# A good key
------BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA X25519 PRIVATE KEY-----
-`)
-	privP256Key := []byte(`# A good key
------BEGIN NEBULA P256 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PRIVATE KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA X25519 PRIVATE KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PRIVATE KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA X25519 PRIVATE KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA X25519 PRIVATE KEY-----`)
-
-	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalPrivateKey(keyBundle)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-	assert.Nil(t, err)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Len(t, k, 32)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-	assert.Nil(t, err)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula private key banner")
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalPrivateKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalEd25519PublicKey(t *testing.T) {
-	pubKey := []byte(`# A good key
------BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA ED25519 PUBLIC KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA ED25519 PUBLIC KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PUBLIC KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA ED25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA ED25519 PUBLIC KEY-----`)
-
-	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, err := UnmarshalEd25519PublicKey(keyBundle)
-	assert.Equal(t, len(k), 32)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-
-	// Fail due to short key
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid ed25519 public key")
-
-	// Fail due to invalid banner
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula Ed25519 public key banner")
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, err = UnmarshalEd25519PublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-func TestUnmarshalX25519PublicKey(t *testing.T) {
-	pubKey := []byte(`# A good key
------BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA X25519 PUBLIC KEY-----
-`)
-	pubP256Key := []byte(`# A good key
------BEGIN NEBULA P256 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
-AAAAAAAAAAAAAAAAAAAAAAA=
------END NEBULA P256 PUBLIC KEY-----
-`)
-	shortKey := []byte(`# A short key
------BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
------END NEBULA X25519 PUBLIC KEY-----
-`)
-	invalidBanner := []byte(`# Invalid banner
------BEGIN NOT A NEBULA PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
------END NOT A NEBULA PUBLIC KEY-----
-`)
-	invalidPem := []byte(`# Not a valid PEM format
--BEGIN NEBULA X25519 PUBLIC KEY-----
-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
--END NEBULA X25519 PUBLIC KEY-----`)
-
-	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
-
-	// Success test case
-	k, rest, curve, err := UnmarshalPublicKey(keyBundle)
-	assert.Equal(t, len(k), 32)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_CURVE25519, curve)
-
-	// Success test case
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Equal(t, len(k), 65)
-	assert.Nil(t, err)
-	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
-	assert.Equal(t, Curve_P256, curve)
-
-	// Fail due to short key
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
-	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
-
-	// Fail due to invalid banner
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.EqualError(t, err, "bytes did not contain a proper nebula public key banner")
-	assert.Equal(t, rest, invalidPem)
-
-	// Fail due to ivalid PEM format, because
-	// it's missing the requisite pre-encapsulation boundary.
-	k, rest, curve, err = UnmarshalPublicKey(rest)
-	assert.Nil(t, k)
-	assert.Equal(t, rest, invalidPem)
-	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
-}
-
-// Ensure that upgrading the protobuf library does not change how certificates
-// are marshalled, since this would break signature verification
-func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
-	before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
-	after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name: "testing",
-			Ips: []*net.IPNet{
-				{IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-				{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-			},
-			Subnets: []*net.IPNet{
-				{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-				{IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))},
-				{IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))},
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		Signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-	assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
-
-	b, err = proto.Marshal(nc.getRawDetails())
-	assert.Nil(t, err)
-	//t.Log("Raw cert size:", len(b))
-	assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
-}
-
-func TestNebulaCertificate_Copy(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	assert.Nil(t, err)
-	cc := c.Copy()
-
-	test.AssertDeepCopyEqual(t, c, cc)
-}
-
-func TestUnmarshalNebulaCertificate(t *testing.T) {
-	// Test that we don't panic with an invalid certificate (#332)
-	data := []byte("\x98\x00\x00")
-	_, err := UnmarshalNebulaCertificate(data)
-	assert.EqualError(t, err, "encoded Details was nil")
-}
-
-func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(Curve_CURVE25519, priv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, priv, nil
-}
-
-func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			Curve:          Curve_P256,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = ips
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(Curve_P256, rawPriv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, rawPriv, nil
-}
-
-func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	if len(groups) == 0 {
-		groups = []string{"test-group1", "test-group2", "test-group3"}
-	}
-
-	if len(ips) == 0 {
-		ips = []*net.IPNet{
-			{IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
-			{IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
-			{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
-		}
-	}
-
-	if len(subnets) == 0 {
-		subnets = []*net.IPNet{
-			{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())},
-			{IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())},
-			{IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())},
-		}
-	}
-
-	var pub, rawPriv []byte
-
-	switch ca.Details.Curve {
-	case Curve_CURVE25519:
-		pub, rawPriv = x25519Keypair()
-	case Curve_P256:
-		pub, rawPriv = p256Keypair()
-	default:
-		return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve)
-	}
-
-	nc := &NebulaCertificate{
-		Details: NebulaCertificateDetails{
-			Name:           "testing",
-			Ips:            ips,
-			Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Curve:          ca.Details.Curve,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	return nc, pub, rawPriv, nil
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
-func p256Keypair() ([]byte, []byte) {
-	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
-	if err != nil {
-		panic(err)
-	}
-	pubkey := privkey.PublicKey()
-	return pubkey.Bytes(), privkey.Bytes()
-}

+ 489 - 0
cert/cert_v1.go

@@ -0,0 +1,489 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"google.golang.org/protobuf/proto"
+)
+
+const publicKeyLen = 32
+
+type certificateV1 struct {
+	details   detailsV1
+	signature []byte
+}
+
+type detailsV1 struct {
+	name           string
+	networks       []netip.Prefix
+	unsafeNetworks []netip.Prefix
+	groups         []string
+	notBefore      time.Time
+	notAfter       time.Time
+	publicKey      []byte
+	isCA           bool
+	issuer         string
+
+	curve Curve
+}
+
+type m map[string]interface{}
+
+func (c *certificateV1) Version() Version {
+	return Version1
+}
+
+func (c *certificateV1) Curve() Curve {
+	return c.details.curve
+}
+
+func (c *certificateV1) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV1) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV1) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV1) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV1) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV1) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV1) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV1) PublicKey() []byte {
+	return c.details.publicKey
+}
+
+func (c *certificateV1) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV1) Fingerprint() (string, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return "", err
+	}
+
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV1) CheckSignature(key []byte) bool {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return false
+	}
+	switch c.details.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV1) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.details.curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+			}
+
+			if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return fmt.Errorf("cannot parse private key as P256: %w", err)
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.details.publicKey) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return err
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return err
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.details.publicKey) {
+		return fmt.Errorf("public key in cert and private key supplied don't match")
+	}
+
+	return nil
+}
+
+// getRawDetails marshals the raw details into protobuf ready struct
+func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
+	rd := &RawNebulaCertificateDetails{
+		Name:      c.details.name,
+		Groups:    c.details.groups,
+		NotBefore: c.details.notBefore.Unix(),
+		NotAfter:  c.details.notAfter.Unix(),
+		PublicKey: make([]byte, len(c.details.publicKey)),
+		IsCA:      c.details.isCA,
+		Curve:     c.details.curve,
+	}
+
+	for _, ipNet := range c.details.networks {
+		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
+		rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
+	}
+
+	for _, ipNet := range c.details.unsafeNetworks {
+		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
+		rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
+	}
+
+	copy(rd.PublicKey, c.details.publicKey[:])
+
+	// I know, this is terrible
+	rd.Issuer, _ = hex.DecodeString(c.details.issuer)
+
+	return rd
+}
+
+func (c *certificateV1) String() string {
+	b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
+	pubKey := c.details.publicKey
+	c.details.publicKey = nil
+	rawCertNoKey, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	c.details.publicKey = pubKey
+	return rawCertNoKey, nil
+}
+
+func (c *certificateV1) Marshal() ([]byte, error) {
+	rc := RawNebulaCertificate{
+		Details:   c.getRawDetails(),
+		Signature: c.signature,
+	}
+
+	return proto.Marshal(&rc)
+}
+
+func (c *certificateV1) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
+}
+
+func (c *certificateV1) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.marshalJSON())
+}
+
+func (c *certificateV1) marshalJSON() m {
+	fp, _ := c.Fingerprint()
+	return m{
+		"version": Version1,
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"publicKey":      fmt.Sprintf("%x", c.details.publicKey),
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+			"curve":          c.details.curve.String(),
+		},
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}
+}
+
+func (c *certificateV1) Copy() Certificate {
+	nc := &certificateV1{
+		details: detailsV1{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			publicKey: make([]byte, len(c.details.publicKey)),
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+			curve:     c.details.curve,
+		},
+		signature: make([]byte, len(c.signature)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.signature, c.signature)
+	copy(nc.details.publicKey, c.details.publicKey)
+
+	return nc
+}
+
+func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV1{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		publicKey:      t.PublicKey,
+		isCA:           t.IsCA,
+		curve:          t.Curve,
+		issuer:         t.issuer,
+	}
+
+	return c.validate()
+}
+
+func (c *certificateV1) validate() error {
+	// Empty names are allowed
+
+	if len(c.details.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
+	// Continue to allow this behavior
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
+	}
+
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+	}
+
+	// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
+	// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
+	// unsafe networks would result in a different signature.
+
+	return nil
+}
+
+func (c *certificateV1) marshalForSigning() ([]byte, error) {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return nil, err
+	}
+	return b, nil
+}
+
+func (c *certificateV1) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
+// if the publicKey is provided here then it is not required to be present in `b`
+func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rc RawNebulaCertificate
+	err := proto.Unmarshal(b, &rc)
+	if err != nil {
+		return nil, err
+	}
+
+	if rc.Details == nil {
+		return nil, fmt.Errorf("encoded Details was nil")
+	}
+
+	if len(rc.Details.Ips)%2 != 0 {
+		return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found")
+	}
+
+	if len(rc.Details.Subnets)%2 != 0 {
+		return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found")
+	}
+
+	nc := certificateV1{
+		details: detailsV1{
+			name:           rc.Details.Name,
+			groups:         make([]string, len(rc.Details.Groups)),
+			networks:       make([]netip.Prefix, len(rc.Details.Ips)/2),
+			unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
+			notBefore:      time.Unix(rc.Details.NotBefore, 0),
+			notAfter:       time.Unix(rc.Details.NotAfter, 0),
+			publicKey:      make([]byte, len(rc.Details.PublicKey)),
+			isCA:           rc.Details.IsCA,
+			curve:          rc.Details.Curve,
+		},
+		signature: make([]byte, len(rc.Signature)),
+	}
+
+	copy(nc.signature, rc.Signature)
+	copy(nc.details.groups, rc.Details.Groups)
+	nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
+
+	if len(publicKey) > 0 {
+		nc.details.publicKey = publicKey
+	}
+
+	copy(nc.details.publicKey, rc.Details.PublicKey)
+
+	var ip netip.Addr
+	for i, rawIp := range rc.Details.Ips {
+		if i%2 == 0 {
+			ip = int2addr(rawIp)
+		} else {
+			ones, _ := net.IPMask(int2ip(rawIp)).Size()
+			nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
+		}
+	}
+
+	for i, rawIp := range rc.Details.Subnets {
+		if i%2 == 0 {
+			ip = int2addr(rawIp)
+		} else {
+			ones, _ := net.IPMask(int2ip(rawIp)).Size()
+			nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
+		}
+	}
+
+	err = nc.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return &nc, nil
+}
+
+func ip2int(ip []byte) uint32 {
+	if len(ip) == 16 {
+		return binary.BigEndian.Uint32(ip[12:16])
+	}
+	return binary.BigEndian.Uint32(ip)
+}
+
+func int2ip(nn uint32) net.IP {
+	ip := make(net.IP, net.IPv4len)
+	binary.BigEndian.PutUint32(ip, nn)
+	return ip
+}
+
+func addr2int(addr netip.Addr) uint32 {
+	b := addr.Unmap().As4()
+	return binary.BigEndian.Uint32(b[:])
+}
+
+func int2addr(nn uint32) netip.Addr {
+	ip := [4]byte{}
+	binary.BigEndian.PutUint32(ip[:], nn)
+	return netip.AddrFrom4(ip).Unmap()
+}

+ 111 - 111
cert/cert.pb.go → cert/cert_v1.pb.go

@@ -1,8 +1,8 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
-// 	protoc-gen-go v1.30.0
+// 	protoc-gen-go v1.34.2
 // 	protoc        v3.21.5
-// source: cert.proto
+// source: cert_v1.proto
 
 package cert
 
@@ -50,11 +50,11 @@ func (x Curve) String() string {
 }
 
 func (Curve) Descriptor() protoreflect.EnumDescriptor {
-	return file_cert_proto_enumTypes[0].Descriptor()
+	return file_cert_v1_proto_enumTypes[0].Descriptor()
 }
 
 func (Curve) Type() protoreflect.EnumType {
-	return &file_cert_proto_enumTypes[0]
+	return &file_cert_v1_proto_enumTypes[0]
 }
 
 func (x Curve) Number() protoreflect.EnumNumber {
@@ -63,7 +63,7 @@ func (x Curve) Number() protoreflect.EnumNumber {
 
 // Deprecated: Use Curve.Descriptor instead.
 func (Curve) EnumDescriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{0}
+	return file_cert_v1_proto_rawDescGZIP(), []int{0}
 }
 
 type RawNebulaCertificate struct {
@@ -78,7 +78,7 @@ type RawNebulaCertificate struct {
 func (x *RawNebulaCertificate) Reset() {
 	*x = RawNebulaCertificate{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[0]
+		mi := &file_cert_v1_proto_msgTypes[0]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -91,7 +91,7 @@ func (x *RawNebulaCertificate) String() string {
 func (*RawNebulaCertificate) ProtoMessage() {}
 
 func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[0]
+	mi := &file_cert_v1_proto_msgTypes[0]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -104,7 +104,7 @@ func (x *RawNebulaCertificate) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaCertificate.ProtoReflect.Descriptor instead.
 func (*RawNebulaCertificate) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{0}
+	return file_cert_v1_proto_rawDescGZIP(), []int{0}
 }
 
 func (x *RawNebulaCertificate) GetDetails() *RawNebulaCertificateDetails {
@@ -143,7 +143,7 @@ type RawNebulaCertificateDetails struct {
 func (x *RawNebulaCertificateDetails) Reset() {
 	*x = RawNebulaCertificateDetails{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[1]
+		mi := &file_cert_v1_proto_msgTypes[1]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -156,7 +156,7 @@ func (x *RawNebulaCertificateDetails) String() string {
 func (*RawNebulaCertificateDetails) ProtoMessage() {}
 
 func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[1]
+	mi := &file_cert_v1_proto_msgTypes[1]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -169,7 +169,7 @@ func (x *RawNebulaCertificateDetails) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaCertificateDetails.ProtoReflect.Descriptor instead.
 func (*RawNebulaCertificateDetails) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{1}
+	return file_cert_v1_proto_rawDescGZIP(), []int{1}
 }
 
 func (x *RawNebulaCertificateDetails) GetName() string {
@@ -254,7 +254,7 @@ type RawNebulaEncryptedData struct {
 func (x *RawNebulaEncryptedData) Reset() {
 	*x = RawNebulaEncryptedData{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[2]
+		mi := &file_cert_v1_proto_msgTypes[2]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -267,7 +267,7 @@ func (x *RawNebulaEncryptedData) String() string {
 func (*RawNebulaEncryptedData) ProtoMessage() {}
 
 func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[2]
+	mi := &file_cert_v1_proto_msgTypes[2]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -280,7 +280,7 @@ func (x *RawNebulaEncryptedData) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaEncryptedData.ProtoReflect.Descriptor instead.
 func (*RawNebulaEncryptedData) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{2}
+	return file_cert_v1_proto_rawDescGZIP(), []int{2}
 }
 
 func (x *RawNebulaEncryptedData) GetEncryptionMetadata() *RawNebulaEncryptionMetadata {
@@ -309,7 +309,7 @@ type RawNebulaEncryptionMetadata struct {
 func (x *RawNebulaEncryptionMetadata) Reset() {
 	*x = RawNebulaEncryptionMetadata{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[3]
+		mi := &file_cert_v1_proto_msgTypes[3]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -322,7 +322,7 @@ func (x *RawNebulaEncryptionMetadata) String() string {
 func (*RawNebulaEncryptionMetadata) ProtoMessage() {}
 
 func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[3]
+	mi := &file_cert_v1_proto_msgTypes[3]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -335,7 +335,7 @@ func (x *RawNebulaEncryptionMetadata) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaEncryptionMetadata.ProtoReflect.Descriptor instead.
 func (*RawNebulaEncryptionMetadata) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{3}
+	return file_cert_v1_proto_rawDescGZIP(), []int{3}
 }
 
 func (x *RawNebulaEncryptionMetadata) GetEncryptionAlgorithm() string {
@@ -367,7 +367,7 @@ type RawNebulaArgon2Parameters struct {
 func (x *RawNebulaArgon2Parameters) Reset() {
 	*x = RawNebulaArgon2Parameters{}
 	if protoimpl.UnsafeEnabled {
-		mi := &file_cert_proto_msgTypes[4]
+		mi := &file_cert_v1_proto_msgTypes[4]
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		ms.StoreMessageInfo(mi)
 	}
@@ -380,7 +380,7 @@ func (x *RawNebulaArgon2Parameters) String() string {
 func (*RawNebulaArgon2Parameters) ProtoMessage() {}
 
 func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
-	mi := &file_cert_proto_msgTypes[4]
+	mi := &file_cert_v1_proto_msgTypes[4]
 	if protoimpl.UnsafeEnabled && x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
@@ -393,7 +393,7 @@ func (x *RawNebulaArgon2Parameters) ProtoReflect() protoreflect.Message {
 
 // Deprecated: Use RawNebulaArgon2Parameters.ProtoReflect.Descriptor instead.
 func (*RawNebulaArgon2Parameters) Descriptor() ([]byte, []int) {
-	return file_cert_proto_rawDescGZIP(), []int{4}
+	return file_cert_v1_proto_rawDescGZIP(), []int{4}
 }
 
 func (x *RawNebulaArgon2Parameters) GetVersion() int32 {
@@ -431,87 +431,87 @@ func (x *RawNebulaArgon2Parameters) GetSalt() []byte {
 	return nil
 }
 
-var File_cert_proto protoreflect.FileDescriptor
-
-var file_cert_proto_rawDesc = []byte{
-	0x0a, 0x0a, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x63, 0x65,
-	0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
-	0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a, 0x07, 0x44, 0x65,
-	0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74,
-	0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x52, 0x07,
-	0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69, 0x67, 0x6e, 0x61,
-	0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53, 0x69, 0x67, 0x6e,
-	0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
-	0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65,
-	0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20,
-	0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x49, 0x70, 0x73,
-	0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x53,
-	0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x07, 0x53, 0x75,
-	0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
-	0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x1c, 0x0a,
-	0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03,
-	0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x4e,
-	0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x4e,
-	0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75, 0x62, 0x6c, 0x69,
-	0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50, 0x75, 0x62, 0x6c,
-	0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41, 0x18, 0x08, 0x20,
-	0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06, 0x49, 0x73, 0x73,
-	0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73, 0x73, 0x75, 0x65,
-	0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20, 0x01, 0x28, 0x0e,
-	0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65, 0x52, 0x05, 0x63,
-	0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
-	0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61, 0x74, 0x61, 0x12,
-	0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
-	0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x63, 0x65,
-	0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72,
-	0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x12,
-	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65, 0x78, 0x74,
-	0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74, 0x65,
-	0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61,
-	0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61,
-	0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
-	0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
-	0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72,
-	0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61,
-	0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f,
-	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
-	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x52,
-	0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72,
-	0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x41,
-	0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x12,
-	0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05,
-	0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x6d,
-	0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65, 0x6d, 0x6f, 0x72,
-	0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d,
-	0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
-	0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e,
-	0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69,
-	0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28,
-	0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75, 0x72, 0x76, 0x65,
-	0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31, 0x39, 0x10, 0x00,
-	0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a, 0x1e, 0x67, 0x69,
-	0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63, 0x6b, 0x68, 0x71,
-	0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62, 0x06, 0x70, 0x72,
-	0x6f, 0x74, 0x6f, 0x33,
+var File_cert_v1_proto protoreflect.FileDescriptor
+
+var file_cert_v1_proto_rawDesc = []byte{
+	0x0a, 0x0d, 0x63, 0x65, 0x72, 0x74, 0x5f, 0x76, 0x31, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
+	0x04, 0x63, 0x65, 0x72, 0x74, 0x22, 0x71, 0x0a, 0x14, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x3b, 0x0a,
+	0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43,
+	0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c,
+	0x73, 0x52, 0x07, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x53, 0x69,
+	0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x53,
+	0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x22, 0x9c, 0x02, 0x0a, 0x1b, 0x52, 0x61, 0x77,
+	0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74,
+	0x65, 0x44, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
+	0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03,
+	0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x03, 0x49, 0x70, 0x73, 0x12, 0x18,
+	0x0a, 0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0d, 0x52,
+	0x07, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x47, 0x72, 0x6f, 0x75,
+	0x70, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73,
+	0x12, 0x1c, 0x0a, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x18, 0x05, 0x20,
+	0x01, 0x28, 0x03, 0x52, 0x09, 0x4e, 0x6f, 0x74, 0x42, 0x65, 0x66, 0x6f, 0x72, 0x65, 0x12, 0x1a,
+	0x0a, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03,
+	0x52, 0x08, 0x4e, 0x6f, 0x74, 0x41, 0x66, 0x74, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x50, 0x75,
+	0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x50,
+	0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x49, 0x73, 0x43, 0x41,
+	0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x49, 0x73, 0x43, 0x41, 0x12, 0x16, 0x0a, 0x06,
+	0x49, 0x73, 0x73, 0x75, 0x65, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x49, 0x73,
+	0x73, 0x75, 0x65, 0x72, 0x12, 0x21, 0x0a, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x18, 0x64, 0x20,
+	0x01, 0x28, 0x0e, 0x32, 0x0b, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x43, 0x75, 0x72, 0x76, 0x65,
+	0x52, 0x05, 0x63, 0x75, 0x72, 0x76, 0x65, 0x22, 0x8b, 0x01, 0x0a, 0x16, 0x52, 0x61, 0x77, 0x4e,
+	0x65, 0x62, 0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x44, 0x61,
+	0x74, 0x61, 0x12, 0x51, 0x0a, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e,
+	0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21,
+	0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x45,
+	0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
+	0x61, 0x52, 0x12, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1e, 0x0a, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65, 0x72, 0x74,
+	0x65, 0x78, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x43, 0x69, 0x70, 0x68, 0x65,
+	0x72, 0x74, 0x65, 0x78, 0x74, 0x22, 0x9c, 0x01, 0x0a, 0x1b, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62,
+	0x75, 0x6c, 0x61, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x65, 0x74,
+	0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x30, 0x0a, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
+	0x69, 0x6f, 0x6e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01,
+	0x28, 0x09, 0x52, 0x13, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x41, 0x6c,
+	0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x4b, 0x0a, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e,
+	0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65, 0x72, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28,
+	0x0b, 0x32, 0x1f, 0x2e, 0x63, 0x65, 0x72, 0x74, 0x2e, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
+	0x72, 0x73, 0x52, 0x10, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65,
+	0x74, 0x65, 0x72, 0x73, 0x22, 0xa3, 0x01, 0x0a, 0x19, 0x52, 0x61, 0x77, 0x4e, 0x65, 0x62, 0x75,
+	0x6c, 0x61, 0x41, 0x72, 0x67, 0x6f, 0x6e, 0x32, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x65, 0x74, 0x65,
+	0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
+	0x01, 0x28, 0x05, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06,
+	0x6d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6d, 0x65,
+	0x6d, 0x6f, 0x72, 0x79, 0x12, 0x20, 0x0a, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c, 0x6c, 0x65, 0x6c,
+	0x69, 0x73, 0x6d, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x70, 0x61, 0x72, 0x61, 0x6c,
+	0x6c, 0x65, 0x6c, 0x69, 0x73, 0x6d, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74,
+	0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x69, 0x74, 0x65, 0x72,
+	0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x18, 0x05,
+	0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x73, 0x61, 0x6c, 0x74, 0x2a, 0x21, 0x0a, 0x05, 0x43, 0x75,
+	0x72, 0x76, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x55, 0x52, 0x56, 0x45, 0x32, 0x35, 0x35, 0x31,
+	0x39, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x50, 0x32, 0x35, 0x36, 0x10, 0x01, 0x42, 0x20, 0x5a,
+	0x1e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6c, 0x61, 0x63,
+	0x6b, 0x68, 0x71, 0x2f, 0x6e, 0x65, 0x62, 0x75, 0x6c, 0x61, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x62,
+	0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (
-	file_cert_proto_rawDescOnce sync.Once
-	file_cert_proto_rawDescData = file_cert_proto_rawDesc
+	file_cert_v1_proto_rawDescOnce sync.Once
+	file_cert_v1_proto_rawDescData = file_cert_v1_proto_rawDesc
 )
 
-func file_cert_proto_rawDescGZIP() []byte {
-	file_cert_proto_rawDescOnce.Do(func() {
-		file_cert_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_proto_rawDescData)
+func file_cert_v1_proto_rawDescGZIP() []byte {
+	file_cert_v1_proto_rawDescOnce.Do(func() {
+		file_cert_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_cert_v1_proto_rawDescData)
 	})
-	return file_cert_proto_rawDescData
+	return file_cert_v1_proto_rawDescData
 }
 
-var file_cert_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
-var file_cert_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
-var file_cert_proto_goTypes = []interface{}{
+var file_cert_v1_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
+var file_cert_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
+var file_cert_v1_proto_goTypes = []any{
 	(Curve)(0),                          // 0: cert.Curve
 	(*RawNebulaCertificate)(nil),        // 1: cert.RawNebulaCertificate
 	(*RawNebulaCertificateDetails)(nil), // 2: cert.RawNebulaCertificateDetails
@@ -519,7 +519,7 @@ var file_cert_proto_goTypes = []interface{}{
 	(*RawNebulaEncryptionMetadata)(nil), // 4: cert.RawNebulaEncryptionMetadata
 	(*RawNebulaArgon2Parameters)(nil),   // 5: cert.RawNebulaArgon2Parameters
 }
-var file_cert_proto_depIdxs = []int32{
+var file_cert_v1_proto_depIdxs = []int32{
 	2, // 0: cert.RawNebulaCertificate.Details:type_name -> cert.RawNebulaCertificateDetails
 	0, // 1: cert.RawNebulaCertificateDetails.curve:type_name -> cert.Curve
 	4, // 2: cert.RawNebulaEncryptedData.EncryptionMetadata:type_name -> cert.RawNebulaEncryptionMetadata
@@ -531,13 +531,13 @@ var file_cert_proto_depIdxs = []int32{
 	0, // [0:4] is the sub-list for field type_name
 }
 
-func init() { file_cert_proto_init() }
-func file_cert_proto_init() {
-	if File_cert_proto != nil {
+func init() { file_cert_v1_proto_init() }
+func file_cert_v1_proto_init() {
+	if File_cert_v1_proto != nil {
 		return
 	}
 	if !protoimpl.UnsafeEnabled {
-		file_cert_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaCertificate); i {
 			case 0:
 				return &v.state
@@ -549,7 +549,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaCertificateDetails); i {
 			case 0:
 				return &v.state
@@ -561,7 +561,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaEncryptedData); i {
 			case 0:
 				return &v.state
@@ -573,7 +573,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[3].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaEncryptionMetadata); i {
 			case 0:
 				return &v.state
@@ -585,7 +585,7 @@ func file_cert_proto_init() {
 				return nil
 			}
 		}
-		file_cert_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+		file_cert_v1_proto_msgTypes[4].Exporter = func(v any, i int) any {
 			switch v := v.(*RawNebulaArgon2Parameters); i {
 			case 0:
 				return &v.state
@@ -602,19 +602,19 @@ func file_cert_proto_init() {
 	out := protoimpl.TypeBuilder{
 		File: protoimpl.DescBuilder{
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
-			RawDescriptor: file_cert_proto_rawDesc,
+			RawDescriptor: file_cert_v1_proto_rawDesc,
 			NumEnums:      1,
 			NumMessages:   5,
 			NumExtensions: 0,
 			NumServices:   0,
 		},
-		GoTypes:           file_cert_proto_goTypes,
-		DependencyIndexes: file_cert_proto_depIdxs,
-		EnumInfos:         file_cert_proto_enumTypes,
-		MessageInfos:      file_cert_proto_msgTypes,
+		GoTypes:           file_cert_v1_proto_goTypes,
+		DependencyIndexes: file_cert_v1_proto_depIdxs,
+		EnumInfos:         file_cert_v1_proto_enumTypes,
+		MessageInfos:      file_cert_v1_proto_msgTypes,
 	}.Build()
-	File_cert_proto = out.File
-	file_cert_proto_rawDesc = nil
-	file_cert_proto_goTypes = nil
-	file_cert_proto_depIdxs = nil
+	File_cert_v1_proto = out.File
+	file_cert_v1_proto_rawDesc = nil
+	file_cert_v1_proto_goTypes = nil
+	file_cert_v1_proto_depIdxs = nil
 }

+ 0 - 0
cert/cert.proto → cert/cert_v1.proto


+ 218 - 0
cert/cert_v1_test.go

@@ -0,0 +1,218 @@
+package cert
+
+import (
+	"fmt"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/protobuf/proto"
+)
+
+func TestCertificateV1_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	assert.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+
+	assert.Equal(t, nc.Version(), Version1)
+	assert.Equal(t, nc.Curve(), Curve_CURVE25519)
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV1_Expired(t *testing.T) {
+	nc := certificateV1{
+		details: detailsV1{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV1_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
+		string(b),
+	)
+}
+
+func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.NotNil(t, err)
+}
+
+// Ensure that upgrading the protobuf library does not change how certificates
+// are marshalled, since this would break signature verification
+func TestMarshalingCertificateV1Consistency(t *testing.T) {
+	before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
+	after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	require.Nil(t, err)
+	assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
+
+	b, err = proto.Marshal(nc.getRawDetails())
+	assert.Nil(t, err)
+	assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
+}
+
+func TestCertificateV1_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV1(t *testing.T) {
+	// Test that we don't panic with an invalid certificate (#332)
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV1(data, nil)
+	assert.EqualError(t, err, "encoded Details was nil")
+}
+
+func appendByteSlices(b ...[]byte) []byte {
+	retSlice := []byte{}
+	for _, v := range b {
+		retSlice = append(retSlice, v...)
+	}
+	return retSlice
+}
+
+func mustParsePrefixUnmapped(s string) netip.Prefix {
+	prefix := netip.MustParsePrefix(s)
+	return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
+}

+ 37 - 0
cert/cert_v2.asn1

@@ -0,0 +1,37 @@
+Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
+
+Name ::= UTF8String (SIZE (1..253))
+Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
+Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
+Curve ::= ENUMERATED {
+    curve25519 (0),
+    p256 (1)
+}
+
+-- The maximum size of a certificate must not exceed 65536 bytes
+Certificate ::= SEQUENCE {
+    details OCTET STRING,
+    curve Curve DEFAULT curve25519,
+    publicKey OCTET STRING,
+    -- signature(details + curve + publicKey) using the appropriate method for curve
+    signature OCTET STRING
+}
+
+Details ::= SEQUENCE {
+    name Name,
+
+    -- At least 1 ipv4 or ipv6 address must be present if isCA is false
+    networks SEQUENCE OF Network OPTIONAL,
+    unsafeNetworks SEQUENCE OF Network OPTIONAL,
+    groups SEQUENCE OF Name OPTIONAL,
+    isCA BOOLEAN DEFAULT false,
+    notBefore Time,
+    notAfter Time,
+
+    -- issuer is only required if isCA is false, if isCA is true then it must not be present
+    issuer OCTET STRING OPTIONAL,
+    ...
+    -- New fields can be added below here
+}
+
+END

+ 730 - 0
cert/cert_v2.go

@@ -0,0 +1,730 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net/netip"
+	"slices"
+	"time"
+
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+	"golang.org/x/crypto/curve25519"
+)
+
+const (
+	classConstructed     = 0x20
+	classContextSpecific = 0x80
+
+	TagCertDetails   = 0 | classConstructed | classContextSpecific
+	TagCertCurve     = 1 | classContextSpecific
+	TagCertPublicKey = 2 | classContextSpecific
+	TagCertSignature = 3 | classContextSpecific
+
+	TagDetailsName           = 0 | classContextSpecific
+	TagDetailsNetworks       = 1 | classConstructed | classContextSpecific
+	TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
+	TagDetailsGroups         = 3 | classConstructed | classContextSpecific
+	TagDetailsIsCA           = 4 | classContextSpecific
+	TagDetailsNotBefore      = 5 | classContextSpecific
+	TagDetailsNotAfter       = 6 | classContextSpecific
+	TagDetailsIssuer         = 7 | classContextSpecific
+)
+
+const (
+	// MaxCertificateSize is the maximum length a valid certificate can be
+	MaxCertificateSize = 65536
+
+	// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
+	MaxNameLength = 253
+
+	// MaxNetworkLength is the maximum length a network value can be.
+	// 16 bytes for an ipv6 address + 1 byte for the prefix length
+	MaxNetworkLength = 17
+)
+
+type certificateV2 struct {
+	details detailsV2
+
+	// RawDetails contains the entire asn.1 DER encoded Details struct
+	// This is to benefit forwards compatibility in signature checking.
+	// signature(RawDetails + Curve + PublicKey) == Signature
+	rawDetails []byte
+	curve      Curve
+	publicKey  []byte
+	signature  []byte
+}
+
+type detailsV2 struct {
+	name           string
+	networks       []netip.Prefix // MUST BE SORTED
+	unsafeNetworks []netip.Prefix // MUST BE SORTED
+	groups         []string
+	isCA           bool
+	notBefore      time.Time
+	notAfter       time.Time
+	issuer         string
+}
+
+func (c *certificateV2) Version() Version {
+	return Version2
+}
+
+func (c *certificateV2) Curve() Curve {
+	return c.curve
+}
+
+func (c *certificateV2) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV2) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV2) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV2) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV2) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV2) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV2) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV2) PublicKey() []byte {
+	return c.publicKey
+}
+
+func (c *certificateV2) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV2) Fingerprint() (string, error) {
+	if len(c.rawDetails) == 0 {
+		return "", ErrMissingDetails
+	}
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV2) CheckSignature(key []byte) bool {
+	if len(c.rawDetails) == 0 {
+		return false
+	}
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+
+	switch c.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV2) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.curve {
+		return ErrPublicPrivateCurveMismatch
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return ErrInvalidPrivateKey
+			}
+
+			if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return ErrInvalidPrivateKey
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.publicKey) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.publicKey) {
+		return ErrPublicPrivateKeyMismatch
+	}
+
+	return nil
+}
+
+func (c *certificateV2) String() string {
+	mb, err := c.marshalJSON()
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+
+	b, err := json.MarshalIndent(mb, "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Skipping the curve and public key since those come across in a different part of the handshake
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) Marshal() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Add the curve only if its not the default value
+		if c.curve != Curve_CURVE25519 {
+			b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
+				b.AddBytes([]byte{byte(c.curve)})
+			})
+		}
+
+		// Add the public key if it is not empty
+		if c.publicKey != nil {
+			b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
+				b.AddBytes(c.publicKey)
+			})
+		}
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
+}
+
+func (c *certificateV2) MarshalJSON() ([]byte, error) {
+	b, err := c.marshalJSON()
+	if err != nil {
+		return nil, err
+	}
+	return json.Marshal(b)
+}
+
+func (c *certificateV2) marshalJSON() (m, error) {
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, err
+	}
+
+	return m{
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+		},
+		"version":     Version2,
+		"publicKey":   fmt.Sprintf("%x", c.publicKey),
+		"curve":       c.curve.String(),
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}, nil
+}
+
+func (c *certificateV2) Copy() Certificate {
+	nc := &certificateV2{
+		details: detailsV2{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+		},
+		curve:      c.curve,
+		publicKey:  make([]byte, len(c.publicKey)),
+		signature:  make([]byte, len(c.signature)),
+		rawDetails: make([]byte, len(c.rawDetails)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.rawDetails, c.rawDetails)
+	copy(nc.signature, c.signature)
+	copy(nc.publicKey, c.publicKey)
+
+	return nc
+}
+
+func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV2{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		isCA:           t.IsCA,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		issuer:         t.issuer,
+	}
+	c.curve = t.Curve
+	c.publicKey = t.PublicKey
+	return c.validate()
+}
+
+func (c *certificateV2) validate() error {
+	// Empty names are allowed
+
+	if len(c.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
+	}
+
+	hasV4Networks := false
+	hasV6Networks := false
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+
+		if network.Addr().Is4In6() {
+			return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
+		}
+
+		hasV4Networks = hasV4Networks || network.Addr().Is4()
+		hasV6Networks = hasV6Networks || network.Addr().Is6()
+	}
+
+	slices.SortFunc(c.details.networks, comparePrefix)
+	err := findDuplicatePrefix(c.details.networks)
+	if err != nil {
+		return err
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+
+		if !c.details.isCA {
+			if network.Addr().Is6() {
+				if !hasV6Networks {
+					return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
+				}
+			} else if network.Addr().Is4() {
+				if !hasV4Networks {
+					return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
+				}
+			}
+		}
+	}
+
+	slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
+	err = findDuplicatePrefix(c.details.unsafeNetworks)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *certificateV2) marshalForSigning() ([]byte, error) {
+	d, err := c.details.Marshal()
+	if err != nil {
+		return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
+	}
+	c.rawDetails = d
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	return b, nil
+}
+
+func (c *certificateV2) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+func (d *detailsV2) Marshal() ([]byte, error) {
+	var b cryptobyte.Builder
+	var err error
+
+	// Details are a structure
+	b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
+
+		// Add the name
+		b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
+			b.AddBytes([]byte(d.name))
+		})
+
+		// Add the networks if any exist
+		if len(d.networks) > 0 {
+			b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.networks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add the unsafe networks if any exist
+		if len(d.unsafeNetworks) > 0 {
+			b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.unsafeNetworks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add groups if any exist
+		if len(d.groups) > 0 {
+			b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
+				for _, group := range d.groups {
+					b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
+						b.AddBytes([]byte(group))
+					})
+				}
+			})
+		}
+
+		// Add IsCA only if true
+		if d.isCA {
+			b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
+				b.AddUint8(0xff)
+			})
+		}
+
+		// Add not before
+		b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
+
+		// Add not after
+		b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
+
+		// Add the issuer if present
+		if d.issuer != "" {
+			issuerBytes, innerErr := hex.DecodeString(d.issuer)
+			if innerErr != nil {
+				err = fmt.Errorf("failed to decode issuer: %w", innerErr)
+				return
+			}
+			b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
+				b.AddBytes(issuerBytes)
+			})
+		}
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return b.Bytes()
+}
+
+func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
+	l := len(b)
+	if l == 0 || l > MaxCertificateSize {
+		return nil, ErrBadFormat
+	}
+
+	input := cryptobyte.String(b)
+	// Open the envelope
+	if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the cert details, we need to preserve the tag and length
+	var rawDetails cryptobyte.String
+	if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	//Maybe grab the curve
+	var rawCurve byte
+	if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
+		return nil, ErrBadFormat
+	}
+	curve = Curve(rawCurve)
+
+	// Maybe grab the public key
+	var rawPublicKey cryptobyte.String
+	if len(publicKey) > 0 {
+		rawPublicKey = publicKey
+	} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
+		return nil, ErrBadFormat
+	}
+
+	if len(rawPublicKey) == 0 {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the signature
+	var rawSignature cryptobyte.String
+	if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Finally unmarshal the details
+	details, err := unmarshalDetails(rawDetails)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &certificateV2{
+		details:    details,
+		rawDetails: rawDetails,
+		curve:      curve,
+		publicKey:  rawPublicKey,
+		signature:  rawSignature,
+	}
+
+	err = c.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
+
+func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
+	// Open the envelope
+	if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the name
+	var name cryptobyte.String
+	if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the network addresses
+	var subString cryptobyte.String
+	var found bool
+
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var networks []netip.Prefix
+	var val cryptobyte.String
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			networks = append(networks, n)
+		}
+	}
+
+	// Read out any unsafe networks
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var unsafeNetworks []netip.Prefix
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			unsafeNetworks = append(unsafeNetworks, n)
+		}
+	}
+
+	// Read out any groups
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var groups []string
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
+				return detailsV2{}, ErrBadFormat
+			}
+			groups = append(groups, string(val))
+		}
+	}
+
+	// Read out IsCA
+	var isCa bool
+	if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read not before and not after
+	var notBefore int64
+	if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var notAfter int64
+	if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read issuer
+	var issuer cryptobyte.String
+	if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	return detailsV2{
+		name:           string(name),
+		networks:       networks,
+		unsafeNetworks: unsafeNetworks,
+		groups:         groups,
+		isCA:           isCa,
+		notBefore:      time.Unix(notBefore, 0),
+		notAfter:       time.Unix(notAfter, 0),
+		issuer:         hex.EncodeToString(issuer),
+	}, nil
+}

+ 267 - 0
cert/cert_v2_test.go

@@ -0,0 +1,267 @@
+package cert
+
+import (
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCertificateV2_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	nc.rawDetails = db
+
+	b, err := nc.Marshal()
+	require.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
+	assert.Nil(t, err)
+
+	assert.Equal(t, nc.Version(), Version2)
+	assert.Equal(t, nc.Curve(), Curve_CURVE25519)
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+	assert.Equal(t, nc.Issuer(), nc2.Issuer())
+
+	// unmarshalling will sort networks and unsafeNetworks, we need to do the same
+	// but first make sure it fails
+	assert.NotEqual(t, nc.Networks(), nc2.Networks())
+	assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	slices.SortFunc(nc.details.networks, comparePrefix)
+	slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV2_Expired(t *testing.T) {
+	nc := certificateV2{
+		details: detailsV2{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV2_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedf1234567890abcedf")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			isCA:      false,
+			issuer:    "1234567890abcedf1234567890abcedf",
+		},
+		publicKey: pubKey,
+		signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
+	}
+
+	b, err := nc.MarshalJSON()
+	assert.ErrorIs(t, err, ErrMissingDetails)
+
+	rd, err := nc.details.Marshal()
+	assert.NoError(t, err)
+
+	nc.rawDetails = rd
+	b, err = nc.MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
+		string(b),
+	)
+}
+
+func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	_, caKey2, err := ed25519.GenerateKey(rand.Reader)
+	require.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	ac, ok := c.(*certificateV2)
+	require.True(t, ok)
+	ac.curve = Curve(99)
+	err = c.VerifyPrivateKey(Curve(99), priv2)
+	assert.EqualError(t, err, "invalid curve: 99")
+
+	ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv)
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	aCa, ok := ca2.(*certificateV2)
+	require.True(t, ok)
+	aCa.curve = Curve(99)
+	err = aCa.VerifyPrivateKey(Curve(99), priv2)
+	assert.EqualError(t, err, "invalid curve: 99")
+
+}
+
+func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestCertificateV2_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV2(t *testing.T) {
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
+	assert.EqualError(t, err, "bad wire format")
+}
+
+func TestCertificateV2_marshalForSigningStability(t *testing.T) {
+	before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
+	after := before.Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
+	expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
+	require.NoError(t, err)
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	assert.Equal(t, expectedRawDetails, db)
+
+	expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
+	b, err := nc.marshalForSigning()
+	require.NoError(t, err)
+	assert.Equal(t, expectedForSigning, b)
+}

+ 159 - 2
cert/crypto.go

@@ -3,14 +3,28 @@ package cert
 import (
 	"crypto/aes"
 	"crypto/cipher"
+	"crypto/ed25519"
 	"crypto/rand"
+	"encoding/pem"
 	"fmt"
 	"io"
+	"math"
 
 	"golang.org/x/crypto/argon2"
+	"google.golang.org/protobuf/proto"
 )
 
-// KDF factors
+type NebulaEncryptedData struct {
+	EncryptionMetadata NebulaEncryptionMetadata
+	Ciphertext         []byte
+}
+
+type NebulaEncryptionMetadata struct {
+	EncryptionAlgorithm string
+	Argon2Parameters    Argon2Parameters
+}
+
+// Argon2Parameters KDF factors
 type Argon2Parameters struct {
 	version     rune
 	Memory      uint32 // KiB
@@ -19,7 +33,7 @@ type Argon2Parameters struct {
 	salt        []byte
 }
 
-// Returns a new Argon2Parameters object with current version set
+// NewArgon2Parameters Returns a new Argon2Parameters object with current version set
 func NewArgon2Parameters(memory uint32, parallelism uint8, iterations uint32) *Argon2Parameters {
 	return &Argon2Parameters{
 		version:     argon2.Version,
@@ -141,3 +155,146 @@ func splitNonceCiphertext(blob []byte, nonceSize int) ([]byte, []byte, error) {
 
 	return blob[:nonceSize], blob[nonceSize:], nil
 }
+
+// EncryptAndMarshalSigningPrivateKey is a simple helper to encrypt and PEM encode a private key
+func EncryptAndMarshalSigningPrivateKey(curve Curve, b []byte, passphrase []byte, kdfParams *Argon2Parameters) ([]byte, error) {
+	ciphertext, err := aes256Encrypt(passphrase, kdfParams, b)
+	if err != nil {
+		return nil, err
+	}
+
+	b, err = proto.Marshal(&RawNebulaEncryptedData{
+		EncryptionMetadata: &RawNebulaEncryptionMetadata{
+			EncryptionAlgorithm: "AES-256-GCM",
+			Argon2Parameters: &RawNebulaArgon2Parameters{
+				Version:     kdfParams.version,
+				Memory:      kdfParams.Memory,
+				Parallelism: uint32(kdfParams.Parallelism),
+				Iterations:  kdfParams.Iterations,
+				Salt:        kdfParams.salt,
+			},
+		},
+		Ciphertext: ciphertext,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedEd25519PrivateKeyBanner, Bytes: b}), nil
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: EncryptedECDSAP256PrivateKeyBanner, Bytes: b}), nil
+	default:
+		return nil, fmt.Errorf("invalid curve: %v", curve)
+	}
+}
+
+// UnmarshalNebulaEncryptedData will unmarshal a protobuf byte representation of a nebula cert into its
+// protobuf-generated struct.
+func UnmarshalNebulaEncryptedData(b []byte) (*NebulaEncryptedData, error) {
+	if len(b) == 0 {
+		return nil, fmt.Errorf("nil byte array")
+	}
+	var rned RawNebulaEncryptedData
+	err := proto.Unmarshal(b, &rned)
+	if err != nil {
+		return nil, err
+	}
+
+	if rned.EncryptionMetadata == nil {
+		return nil, fmt.Errorf("encoded EncryptionMetadata was nil")
+	}
+
+	if rned.EncryptionMetadata.Argon2Parameters == nil {
+		return nil, fmt.Errorf("encoded Argon2Parameters was nil")
+	}
+
+	params, err := unmarshalArgon2Parameters(rned.EncryptionMetadata.Argon2Parameters)
+	if err != nil {
+		return nil, err
+	}
+
+	ned := NebulaEncryptedData{
+		EncryptionMetadata: NebulaEncryptionMetadata{
+			EncryptionAlgorithm: rned.EncryptionMetadata.EncryptionAlgorithm,
+			Argon2Parameters:    *params,
+		},
+		Ciphertext: rned.Ciphertext,
+	}
+
+	return &ned, nil
+}
+
+func unmarshalArgon2Parameters(params *RawNebulaArgon2Parameters) (*Argon2Parameters, error) {
+	if params.Version < math.MinInt32 || params.Version > math.MaxInt32 {
+		return nil, fmt.Errorf("Argon2Parameters Version must be at least %d and no more than %d", math.MinInt32, math.MaxInt32)
+	}
+	if params.Memory <= 0 || params.Memory > math.MaxUint32 {
+		return nil, fmt.Errorf("Argon2Parameters Memory must be be greater than 0 and no more than %d KiB", uint32(math.MaxUint32))
+	}
+	if params.Parallelism <= 0 || params.Parallelism > math.MaxUint8 {
+		return nil, fmt.Errorf("Argon2Parameters Parallelism must be be greater than 0 and no more than %d", math.MaxUint8)
+	}
+	if params.Iterations <= 0 || params.Iterations > math.MaxUint32 {
+		return nil, fmt.Errorf("-argon-iterations must be be greater than 0 and no more than %d", uint32(math.MaxUint32))
+	}
+
+	return &Argon2Parameters{
+		version:     params.Version,
+		Memory:      params.Memory,
+		Parallelism: uint8(params.Parallelism),
+		Iterations:  params.Iterations,
+		salt:        params.Salt,
+	}, nil
+
+}
+
+// DecryptAndUnmarshalSigningPrivateKey will try to pem decode and decrypt an Ed25519/ECDSA private key with
+// the given passphrase, returning any other bytes b or an error on failure
+func DecryptAndUnmarshalSigningPrivateKey(passphrase, b []byte) (Curve, []byte, []byte, error) {
+	var curve Curve
+
+	k, r := pem.Decode(b)
+	if k == nil {
+		return curve, nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+	case EncryptedECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+	default:
+		return curve, nil, r, fmt.Errorf("bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	}
+
+	ned, err := UnmarshalNebulaEncryptedData(k.Bytes)
+	if err != nil {
+		return curve, nil, r, err
+	}
+
+	var bytes []byte
+	switch ned.EncryptionMetadata.EncryptionAlgorithm {
+	case "AES-256-GCM":
+		bytes, err = aes256Decrypt(passphrase, &ned.EncryptionMetadata.Argon2Parameters, ned.Ciphertext)
+		if err != nil {
+			return curve, nil, r, err
+		}
+	default:
+		return curve, nil, r, fmt.Errorf("unsupported encryption algorithm: %s", ned.EncryptionMetadata.EncryptionAlgorithm)
+	}
+
+	switch curve {
+	case Curve_CURVE25519:
+		if len(bytes) != ed25519.PrivateKeySize {
+			return curve, nil, r, fmt.Errorf("key was not %d bytes, is invalid ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case Curve_P256:
+		if len(bytes) != 32 {
+			return curve, nil, r, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	}
+
+	return curve, bytes, r, nil
+}

+ 87 - 0
cert/crypto_test.go

@@ -23,3 +23,90 @@ func TestNewArgon2Parameters(t *testing.T) {
 		Iterations:  1,
 	}, p)
 }
+
+func TestDecryptAndUnmarshalSigningPrivateKey(t *testing.T) {
+	passphrase := []byte("DO NOT USE THIS KEY")
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A key which, once decrypted, is too short
+-----BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCoga5h8owMEBWRSMMJKzuUvWce7
+k0qlBkQmCxiuLh80MuASW70YcKt8jeEIS2axo2V6zAKA9TSMcCsJW1kDDXEtL/xe
+GLF5T7sDl5COp4LU3pGxpV+KoeQ/S3gQCAAcnaOtnJQX+aSDnbO3jCHyP7U9CHbs
+rQr3bdH3Oy/WiYU=
+-----END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner (not encrypted)
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+bWRp2CTVFhW9HD/qCd28ltDgK3w8VXSeaEYczDWos8sMUBqDb9jP3+NYwcS4lURG
+XgLvodMXZJuaFPssp+WwtA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+CjwKC0FFUy0yNTYtR0NNEi0IExCAgIABGAEgBCognnjujd67Vsv99p22wfAjQaDT
+oCMW1mdjkU3gACKNW4MSXOWR9Sts4C81yk1RUku2gvGKs3TB9LYoklLsIizSYOLl
++Vs//O1T0I1Xbml2XBAROsb/VSoDln/6LMqR4B6fn6B3GOsLBBqRI8daDl9lRMPB
+qrlJ69wer3ZUHFXA
+-END NEBULA ED25519 ENCRYPTED PRIVATE KEY-----
+`)
+
+	keyBundle := appendByteSlices(privKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
+	assert.Nil(t, err)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+
+	// Fail due to invalid banner
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to invalid passphrase
+	curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
+	assert.EqualError(t, err, "invalid passphrase or corrupt private key")
+	assert.Nil(t, k)
+	assert.Equal(t, rest, []byte{})
+}
+
+func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
+	// Having proved that decryption works correctly above, we can test the
+	// encryption function produces a value which can be decrypted
+	passphrase := []byte("passphrase")
+	bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
+	kdfParams := NewArgon2Parameters(64*1024, 4, 3)
+	key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
+	assert.Nil(t, err)
+
+	// Verify the "key" can be decrypted successfully
+	curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
+	assert.Len(t, k, 64)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, []byte{})
+	assert.Nil(t, err)
+
+	// EncryptAndMarshalEd25519PrivateKey does not create any errors itself
+}

+ 41 - 6
cert/errors.go

@@ -2,13 +2,48 @@ package cert
 
 import (
 	"errors"
+	"fmt"
 )
 
 var (
-	ErrRootExpired       = errors.New("root certificate is expired")
-	ErrExpired           = errors.New("certificate is expired")
-	ErrNotCA             = errors.New("certificate is not a CA")
-	ErrNotSelfSigned     = errors.New("certificate is not self-signed")
-	ErrBlockListed       = errors.New("certificate is in the block list")
-	ErrSignatureMismatch = errors.New("certificate signature did not match")
+	ErrBadFormat                  = errors.New("bad wire format")
+	ErrRootExpired                = errors.New("root certificate is expired")
+	ErrExpired                    = errors.New("certificate is expired")
+	ErrNotCA                      = errors.New("certificate is not a CA")
+	ErrNotSelfSigned              = errors.New("certificate is not self-signed")
+	ErrBlockListed                = errors.New("certificate is in the block list")
+	ErrFingerprintMismatch        = errors.New("certificate fingerprint did not match")
+	ErrSignatureMismatch          = errors.New("certificate signature did not match")
+	ErrInvalidPublicKey           = errors.New("invalid public key")
+	ErrInvalidPrivateKey          = errors.New("invalid private key")
+	ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
+	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
+	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
+	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
+
+	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
+	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")
+	ErrInvalidPEMX25519PublicKeyBanner   = errors.New("bytes did not contain a proper X25519 public key banner")
+	ErrInvalidPEMX25519PrivateKeyBanner  = errors.New("bytes did not contain a proper X25519 private key banner")
+	ErrInvalidPEMEd25519PublicKeyBanner  = errors.New("bytes did not contain a proper Ed25519 public key banner")
+	ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
+
+	ErrNoPeerStaticKey = errors.New("no peer static key was present")
+	ErrNoPayload       = errors.New("provided payload was empty")
+
+	ErrMissingDetails  = errors.New("certificate did not contain details")
+	ErrEmptySignature  = errors.New("empty signature")
+	ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
 )
+
+type ErrInvalidCertificateProperties struct {
+	str string
+}
+
+func NewErrInvalidCertificateProperties(format string, a ...any) error {
+	return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
+}
+
+func (e *ErrInvalidCertificateProperties) Error() string {
+	return e.str
+}

+ 141 - 0
cert/helper_test.go

@@ -0,0 +1,141 @@
+package cert
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 161 - 0
cert/pem.go

@@ -0,0 +1,161 @@
+package cert
+
+import (
+	"encoding/pem"
+	"fmt"
+
+	"golang.org/x/crypto/ed25519"
+)
+
+const (
+	CertificateBanner                = "NEBULA CERTIFICATE"
+	CertificateV2Banner              = "NEBULA CERTIFICATE V2"
+	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
+	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
+	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
+	Ed25519PrivateKeyBanner          = "NEBULA ED25519 PRIVATE KEY"
+	Ed25519PublicKeyBanner           = "NEBULA ED25519 PUBLIC KEY"
+
+	P256PrivateKeyBanner               = "NEBULA P256 PRIVATE KEY"
+	P256PublicKeyBanner                = "NEBULA P256 PUBLIC KEY"
+	EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY"
+	ECDSAP256PrivateKeyBanner          = "NEBULA ECDSA P256 PRIVATE KEY"
+)
+
+// UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed
+// data or an error on failure
+func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
+	p, r := pem.Decode(b)
+	if p == nil {
+		return nil, r, ErrInvalidPEMBlock
+	}
+
+	var c Certificate
+	var err error
+
+	switch p.Type {
+	// Implementations must validate the resulting certificate contains valid information
+	case CertificateBanner:
+		c, err = unmarshalCertificateV1(p.Bytes, nil)
+	case CertificateV2Banner:
+		c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
+	default:
+		return nil, r, ErrInvalidPEMCertificateBanner
+	}
+
+	if err != nil {
+		return nil, r, err
+	}
+
+	return c, r, nil
+
+}
+
+func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PublicKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PublicKeyBanner, Ed25519PublicKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PublicKeyBanner:
+		// Uncompressed
+		expectedLen = 65
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper public key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s public key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func MarshalPrivateKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: X25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: P256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+func MarshalSigningPrivateKeyToPEM(curve Curve, b []byte) []byte {
+	switch curve {
+	case Curve_CURVE25519:
+		return pem.EncodeToMemory(&pem.Block{Type: Ed25519PrivateKeyBanner, Bytes: b})
+	case Curve_P256:
+		return pem.EncodeToMemory(&pem.Block{Type: ECDSAP256PrivateKeyBanner, Bytes: b})
+	default:
+		return nil
+	}
+}
+
+// UnmarshalPrivateKeyFromPEM will try to unmarshal the first pem block in a byte array, returning any non
+// consumed data or an error on failure
+func UnmarshalPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var expectedLen int
+	var curve Curve
+	switch k.Type {
+	case X25519PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_CURVE25519
+	case P256PrivateKeyBanner:
+		expectedLen = 32
+		curve = Curve_P256
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper private key banner")
+	}
+	if len(k.Bytes) != expectedLen {
+		return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid %s private key", expectedLen, curve)
+	}
+	return k.Bytes, r, curve, nil
+}
+
+func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) {
+	k, r := pem.Decode(b)
+	if k == nil {
+		return nil, r, 0, fmt.Errorf("input did not contain a valid PEM encoded block")
+	}
+	var curve Curve
+	switch k.Type {
+	case EncryptedEd25519PrivateKeyBanner:
+		return nil, nil, Curve_CURVE25519, ErrPrivateKeyEncrypted
+	case EncryptedECDSAP256PrivateKeyBanner:
+		return nil, nil, Curve_P256, ErrPrivateKeyEncrypted
+	case Ed25519PrivateKeyBanner:
+		curve = Curve_CURVE25519
+		if len(k.Bytes) != ed25519.PrivateKeySize {
+			return nil, r, 0, fmt.Errorf("key was not %d bytes, is invalid Ed25519 private key", ed25519.PrivateKeySize)
+		}
+	case ECDSAP256PrivateKeyBanner:
+		curve = Curve_P256
+		if len(k.Bytes) != 32 {
+			return nil, r, 0, fmt.Errorf("key was not 32 bytes, is invalid ECDSA P256 private key")
+		}
+	default:
+		return nil, r, 0, fmt.Errorf("bytes did not contain a proper Ed25519/ECDSA private key banner")
+	}
+	return k.Bytes, r, curve, nil
+}

+ 292 - 0
cert/pem_test.go

@@ -0,0 +1,292 @@
+package cert
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestUnmarshalCertificateFromPEM(t *testing.T) {
+	goodCert := []byte(`
+# A good cert
+-----BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NEBULA CERTIFICATE-----
+`)
+	badBanner := []byte(`# A bad banner
+-----BEGIN NOT A NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-----END NOT A NEBULA CERTIFICATE-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA CERTIFICATE-----
+CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
+vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
+bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+-END NEBULA CERTIFICATE----`)
+
+	certBundle := appendByteSlices(goodCert, badBanner, invalidPem)
+
+	// Success test case
+	cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
+	assert.NotNil(t, cert)
+	assert.Equal(t, rest, append(badBanner, invalidPem...))
+	assert.Nil(t, err)
+
+	// Fail due to invalid banner.
+	cert, rest, err = UnmarshalCertificateFromPEM(rest)
+	assert.Nil(t, cert)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "bytes did not contain a proper certificate banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	cert, rest, err = UnmarshalCertificateFromPEM(rest)
+	assert.Nil(t, cert)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA ECDSA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ECDSA P256 PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+-----END NEBULA ED25519 PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-END NEBULA ED25519 PRIVATE KEY-----`)
+
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalSigningPrivateKeyFromPEM(keyBundle)
+	assert.Len(t, k, 64)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+	assert.Nil(t, err)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
+	privKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+	privP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PRIVATE KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PRIVATE KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PRIVATE KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PRIVATE KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PRIVATE KEY-----`)
+
+	keyBundle := appendByteSlices(privKey, privP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPrivateKeyFromPEM(keyBundle)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Len(t, k, 32)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+	assert.Nil(t, err)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "bytes did not contain a proper private key banner")
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
+	pubKey := []byte(`# A good key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA ED25519 PUBLIC KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA ED25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA ED25519 PUBLIC KEY-----`)
+
+	keyBundle := appendByteSlices(pubKey, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
+	assert.Equal(t, 32, len(k))
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.EqualError(t, err, "bytes did not contain a proper public key banner")
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}
+
+func TestUnmarshalX25519PublicKey(t *testing.T) {
+	pubKey := []byte(`# A good key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+	pubP256Key := []byte(`# A good key
+-----BEGIN NEBULA P256 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
+AAAAAAAAAAAAAAAAAAAAAAA=
+-----END NEBULA P256 PUBLIC KEY-----
+`)
+	shortKey := []byte(`# A short key
+-----BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==
+-----END NEBULA X25519 PUBLIC KEY-----
+`)
+	invalidBanner := []byte(`# Invalid banner
+-----BEGIN NOT A NEBULA PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-----END NOT A NEBULA PUBLIC KEY-----
+`)
+	invalidPem := []byte(`# Not a valid PEM format
+-BEGIN NEBULA X25519 PUBLIC KEY-----
+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
+-END NEBULA X25519 PUBLIC KEY-----`)
+
+	keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem)
+
+	// Success test case
+	k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
+	assert.Equal(t, 32, len(k))
+	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_CURVE25519, curve)
+
+	// Success test case
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Equal(t, 65, len(k))
+	assert.Nil(t, err)
+	assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
+	assert.Equal(t, Curve_P256, curve)
+
+	// Fail due to short key
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
+	assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
+
+	// Fail due to invalid banner
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.EqualError(t, err, "bytes did not contain a proper public key banner")
+	assert.Equal(t, rest, invalidPem)
+
+	// Fail due to ivalid PEM format, because
+	// it's missing the requisite pre-encapsulation boundary.
+	k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
+	assert.Nil(t, k)
+	assert.Equal(t, rest, invalidPem)
+	assert.EqualError(t, err, "input did not contain a valid PEM encoded block")
+}

+ 167 - 0
cert/sign.go

@@ -0,0 +1,167 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"math/big"
+	"net/netip"
+	"time"
+)
+
+// TBSCertificate represents a certificate intended to be signed.
+// It is invalid to use this structure as a Certificate.
+type TBSCertificate struct {
+	Version        Version
+	Name           string
+	Networks       []netip.Prefix
+	UnsafeNetworks []netip.Prefix
+	Groups         []string
+	IsCA           bool
+	NotBefore      time.Time
+	NotAfter       time.Time
+	PublicKey      []byte
+	Curve          Curve
+	issuer         string
+}
+
+type beingSignedCertificate interface {
+	// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
+	// Implementations must validate the resulting certificate contains valid information
+	fromTBSCertificate(*TBSCertificate) error
+
+	// marshalForSigning returns the bytes that should be signed
+	marshalForSigning() ([]byte, error)
+
+	// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
+	setSignature([]byte) error
+}
+
+type SignerLambda func(certBytes []byte) ([]byte, error)
+
+// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
+// details do not violate constraints of the signing certificate.
+// If the TBSCertificate is a CA then signer must be nil.
+func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
+	switch t.Curve {
+	case Curve_CURVE25519:
+		pk := ed25519.PrivateKey(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			sig := ed25519.Sign(pk, certBytes)
+			return sig, nil
+		}
+		return t.SignWith(signer, curve, sp)
+	case Curve_P256:
+		pk := &ecdsa.PrivateKey{
+			PublicKey: ecdsa.PublicKey{
+				Curve: elliptic.P256(),
+			},
+			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+			D: new(big.Int).SetBytes(key),
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			// We need to hash first for ECDSA
+			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
+			hashed := sha256.Sum256(certBytes)
+			return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
+		}
+		return t.SignWith(signer, curve, sp)
+	default:
+		return nil, fmt.Errorf("invalid curve: %s", t.Curve)
+	}
+}
+
+// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
+// You should only use SignWith if you do not have direct access to your private key.
+func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
+	if curve != t.Curve {
+		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+
+	if signer != nil {
+		if t.IsCA {
+			return nil, fmt.Errorf("can not sign a CA certificate with another")
+		}
+
+		err := checkCAConstraints(signer, t.NotBefore, t.NotAfter, t.Groups, t.Networks, t.UnsafeNetworks)
+		if err != nil {
+			return nil, err
+		}
+
+		issuer, err := signer.Fingerprint()
+		if err != nil {
+			return nil, fmt.Errorf("error computing issuer: %v", err)
+		}
+		t.issuer = issuer
+	} else {
+		if !t.IsCA {
+			return nil, fmt.Errorf("self signed certificates must have IsCA set to true")
+		}
+	}
+
+	var c beingSignedCertificate
+	switch t.Version {
+	case Version1:
+		c = &certificateV1{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	case Version2:
+		c = &certificateV2{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	default:
+		return nil, fmt.Errorf("unknown cert version %d", t.Version)
+	}
+
+	certBytes, err := c.marshalForSigning()
+	if err != nil {
+		return nil, err
+	}
+
+	sig, err := sp(certBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	err = c.setSignature(sig)
+	if err != nil {
+		return nil, err
+	}
+
+	sc, ok := c.(Certificate)
+	if !ok {
+		return nil, fmt.Errorf("invalid certificate")
+	}
+
+	return sc, nil
+}
+
+func comparePrefix(a, b netip.Prefix) int {
+	addr := a.Addr().Compare(b.Addr())
+	if addr == 0 {
+		return a.Bits() - b.Bits()
+	}
+	return addr
+}
+
+// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
+func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
+	if len(sortedPrefixes) < 2 {
+		return nil
+	}
+	for i := 1; i < len(sortedPrefixes); i++ {
+		if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
+			return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
+		}
+	}
+	return nil
+}

+ 90 - 0
cert/sign_test.go

@@ -0,0 +1,90 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCertificateV1_Sign(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/24"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+	}
+
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
+	assert.Nil(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	assert.Nil(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, uc)
+}
+
+func TestCertificateV1_SignP256(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/16"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+		Curve:     Curve_P256,
+	}
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	assert.NoError(t, err)
+	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
+	rawPriv := priv.D.FillBytes(make([]byte, 32))
+
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
+	assert.Nil(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	assert.Nil(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, uc)
+}

+ 138 - 0
cert_test/cert.go

@@ -0,0 +1,138 @@
+package cert_test
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case cert.Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &cert.TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case cert.Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &cert.TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 137 - 73
cmd/nebula-cert/ca.go

@@ -8,13 +8,14 @@ import (
 	"fmt"
 	"io"
 	"math"
-	"net"
+	"net/netip"
 	"os"
 	"strings"
 	"time"
 
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/ed25519"
 )
 
@@ -26,32 +27,43 @@ type caFlags struct {
 	outCertPath      *string
 	outQRPath        *string
 	groups           *string
-	ips              *string
-	subnets          *string
+	networks         *string
+	unsafeNetworks   *string
 	argonMemory      *uint
 	argonIterations  *uint
 	argonParallelism *uint
 	encryption       *bool
+	version          *uint
 
-	curve *string
+	curve  *string
+	p11url *string
+
+	// Deprecated options
+	ips     *string
+	subnets *string
 }
 
 func newCaFlags() *caFlags {
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf.set.Usage = func() {}
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
+	cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
-	cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
-	cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
+	cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
+	cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
+
+	cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
+	cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &cf
 }
 
@@ -76,17 +88,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		return err
 	}
 
+	isP11 := len(*cf.p11url) > 0
+
 	if err := mustFlagString("name", cf.name); err != nil {
 		return err
 	}
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
 	if err := mustFlagString("out-crt", cf.outCertPath); err != nil {
 		return err
 	}
 	var kdfParams *cert.Argon2Parameters
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		if kdfParams, err = parseArgonParameters(*cf.argonMemory, *cf.argonParallelism, *cf.argonIterations); err != nil {
 			return err
 		}
@@ -106,44 +122,57 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		}
 	}
 
-	var ips []*net.IPNet
-	if *cf.ips != "" {
-		for _, rs := range strings.Split(*cf.ips, ",") {
+	version := cert.Version(*cf.version)
+	if version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
+	var networks []netip.Prefix
+	if *cf.networks == "" && *cf.ips != "" {
+		// Pull up deprecated -ips flag if needed
+		*cf.networks = *cf.ips
+	}
+
+	if *cf.networks != "" {
+		for _, rs := range strings.Split(*cf.networks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				ip, ipNet, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid ip definition: %s", err)
+					return newHelpErrorf("invalid -networks definition: %s", rs)
 				}
-				if ip.To4() == nil {
-					return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
-
-				ipNet.IP = ip
-				ips = append(ips, ipNet)
+				networks = append(networks, n)
 			}
 		}
 	}
 
-	var subnets []*net.IPNet
-	if *cf.subnets != "" {
-		for _, rs := range strings.Split(*cf.subnets, ",") {
+	var unsafeNetworks []netip.Prefix
+	if *cf.unsafeNetworks == "" && *cf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*cf.unsafeNetworks = *cf.subnets
+	}
+
+	if *cf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				_, s, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
-				if s.IP.To4() == nil {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
-				subnets = append(subnets, s)
+				unsafeNetworks = append(unsafeNetworks, n)
 			}
 		}
 	}
 
 	var passphrase []byte
-	if *cf.encryption {
+	if !isP11 && *cf.encryption {
 		for i := 0; i < 5; i++ {
 			out.Write([]byte("Enter passphrase: "))
 			passphrase, err = pr.ReadPassword()
@@ -166,74 +195,109 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 
 	var curve cert.Curve
 	var pub, rawPriv []byte
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		curve = cert.Curve_CURVE25519
-		pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+
+		p11Client, err = pkclient.FromUrl(*cf.p11url)
 		if err != nil {
-			return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
 		}
-	case "P256":
-		var key *ecdsa.PrivateKey
-		curve = cert.Curve_P256
-		key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
 		if err != nil {
-			return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
 		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			curve = cert.Curve_CURVE25519
+			pub, rawPriv, err = ed25519.GenerateKey(rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ed25519 keys: %s", err)
+			}
+		case "P256":
+			var key *ecdsa.PrivateKey
+			curve = cert.Curve_P256
+			key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+			if err != nil {
+				return fmt.Errorf("error while generating ecdsa keys: %s", err)
+			}
 
-		// ecdh.PrivateKey lets us get at the encoded bytes, even though
-		// we aren't using ECDH here.
-		eKey, err := key.ECDH()
-		if err != nil {
-			return fmt.Errorf("error while converting ecdsa key: %s", err)
+			// ecdh.PrivateKey lets us get at the encoded bytes, even though
+			// we aren't using ECDH here.
+			eKey, err := key.ECDH()
+			if err != nil {
+				return fmt.Errorf("error while converting ecdsa key: %s", err)
+			}
+			rawPriv = eKey.Bytes()
+			pub = eKey.PublicKey().Bytes()
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
 		}
-		rawPriv = eKey.Bytes()
-		pub = eKey.PublicKey().Bytes()
 	}
 
-	nc := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      *cf.name,
-			Groups:    groups,
-			Ips:       ips,
-			Subnets:   subnets,
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(*cf.duration),
-			PublicKey: pub,
-			IsCA:      true,
-			Curve:     curve,
-		},
+	t := &cert.TBSCertificate{
+		Version:        version,
+		Name:           *cf.name,
+		Groups:         groups,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		NotBefore:      time.Now(),
+		NotAfter:       time.Now().Add(*cf.duration),
+		PublicKey:      pub,
+		IsCA:           true,
+		Curve:          curve,
 	}
 
-	if _, err := os.Stat(*cf.outKeyPath); err == nil {
-		return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+	if !isP11 {
+		if _, err := os.Stat(*cf.outKeyPath); err == nil {
+			return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath)
+		}
 	}
 
 	if _, err := os.Stat(*cf.outCertPath); err == nil {
 		return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath)
 	}
 
-	err = nc.Sign(curve, rawPriv)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
-	}
-
+	var c cert.Certificate
 	var b []byte
-	if *cf.encryption {
-		b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+
+	if isP11 {
+		c, err = t.SignWith(nil, curve, p11Client.SignASN1)
 		if err != nil {
-			return fmt.Errorf("error while encrypting out-key: %s", err)
+			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 		}
 	} else {
-		b = cert.MarshalSigningPrivateKey(curve, rawPriv)
-	}
+		c, err = t.Sign(nil, curve, rawPriv)
+		if err != nil {
+			return fmt.Errorf("error while signing: %s", err)
+		}
 
-	err = os.WriteFile(*cf.outKeyPath, b, 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+		if *cf.encryption {
+			b, err = cert.EncryptAndMarshalSigningPrivateKey(curve, rawPriv, passphrase, kdfParams)
+			if err != nil {
+				return fmt.Errorf("error while encrypting out-key: %s", err)
+			}
+		} else {
+			b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv)
+		}
+
+		err = os.WriteFile(*cf.outKeyPath, b, 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
 
-	b, err = nc.MarshalToPEM()
+	b, err = c.MarshalPEM()
 	if err != nil {
 		return fmt.Errorf("error while marshalling certificate: %s", err)
 	}

+ 36 - 30
cmd/nebula-cert/ca_test.go

@@ -16,8 +16,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//TODO: test file permissions
-
 func Test_caSummary(t *testing.T) {
 	assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
 }
@@ -43,17 +41,24 @@ func Test_caHelp(t *testing.T) {
 			"  -groups string\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"  -ips string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
+			"    	Deprecated, see -networks\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the certificate authority\n"+
+			"  -networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"  -out-key string\n"+
 			"    \tOptional: path to write the private key to (default \"ca.key\")\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use (default 2)\n",
 		ob.String(),
 	)
 }
@@ -82,25 +87,25 @@ func Test_ca(t *testing.T) {
 
 	// required args
 	assertHelpError(t, ca(
-		[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
+		[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
 	), "-name is required")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// ipv4 only ips
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// ipv4 only subnets
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
-	args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
+	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -108,12 +113,12 @@ func Test_ca(t *testing.T) {
 	// create temp key file
 	keyF, err := os.CreateTemp("", "test.key")
 	assert.Nil(t, err)
-	os.Remove(keyF.Name())
+	assert.Nil(t, os.Remove(keyF.Name()))
 
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -121,45 +126,46 @@ func Test_ca(t *testing.T) {
 	// create temp cert file
 	crtF, err := os.CreateTemp("", "test.crt")
 	assert.Nil(t, err)
-	os.Remove(crtF.Name())
-	os.Remove(keyF.Name())
+	assert.Nil(t, os.Remove(crtF.Name()))
+	assert.Nil(t, os.Remove(keyF.Name()))
 
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalEd25519PrivateKey(rb)
+	lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, c)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 64)
 
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
+	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 
-	assert.Equal(t, "test", lCrt.Details.Name)
-	assert.Len(t, lCrt.Details.Ips, 0)
-	assert.True(t, lCrt.Details.IsCA)
-	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
-	assert.Len(t, lCrt.Details.Subnets, 0)
-	assert.Len(t, lCrt.Details.PublicKey, 32)
-	assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
-	assert.Equal(t, "", lCrt.Details.Issuer)
-	assert.True(t, lCrt.CheckSignature(lCrt.Details.PublicKey))
+	assert.Equal(t, "test", lCrt.Name())
+	assert.Len(t, lCrt.Networks(), 0)
+	assert.True(t, lCrt.IsCA())
+	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
+	assert.Len(t, lCrt.UnsafeNetworks(), 0)
+	assert.Len(t, lCrt.PublicKey(), 32)
+	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
+	assert.Equal(t, "", lCrt.Issuer())
+	assert.True(t, lCrt.CheckSignature(lCrt.PublicKey()))
 
 	// test encrypted key
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
@@ -187,7 +193,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
@@ -197,7 +203,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, "", eb.String())
@@ -207,13 +213,13 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -222,7 +228,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())

+ 49 - 20
cmd/nebula-cert/keygen.go

@@ -6,6 +6,8 @@ import (
 	"io"
 	"os"
 
+	"github.com/slackhq/nebula/pkclient"
+
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -13,8 +15,8 @@ type keygenFlags struct {
 	set        *flag.FlagSet
 	outKeyPath *string
 	outPubPath *string
-
-	curve *string
+	curve      *string
+	p11url     *string
 }
 
 func newKeygenFlags() *keygenFlags {
@@ -23,6 +25,7 @@ func newKeygenFlags() *keygenFlags {
 	cf.outPubPath = cf.set.String("out-pub", "", "Required: path to write the public key to")
 	cf.outKeyPath = cf.set.String("out-key", "", "Required: path to write the private key to")
 	cf.curve = cf.set.String("curve", "25519", "ECDH Curve (25519, P256)")
+	cf.p11url = p11Flag(cf.set)
 	return &cf
 }
 
@@ -33,32 +36,58 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error {
 		return err
 	}
 
-	if err := mustFlagString("out-key", cf.outKeyPath); err != nil {
-		return err
+	isP11 := len(*cf.p11url) > 0
+
+	if !isP11 {
+		if err = mustFlagString("out-key", cf.outKeyPath); err != nil {
+			return err
+		}
 	}
-	if err := mustFlagString("out-pub", cf.outPubPath); err != nil {
+	if err = mustFlagString("out-pub", cf.outPubPath); err != nil {
 		return err
 	}
 
 	var pub, rawPriv []byte
 	var curve cert.Curve
-	switch *cf.curve {
-	case "25519", "X25519", "Curve25519", "CURVE25519":
-		pub, rawPriv = x25519Keypair()
-		curve = cert.Curve_CURVE25519
-	case "P256":
-		pub, rawPriv = p256Keypair()
-		curve = cert.Curve_P256
-	default:
-		return fmt.Errorf("invalid curve: %s", *cf.curve)
+	if isP11 {
+		switch *cf.curve {
+		case "P256":
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve for PKCS#11: %s", *cf.curve)
+		}
+	} else {
+		switch *cf.curve {
+		case "25519", "X25519", "Curve25519", "CURVE25519":
+			pub, rawPriv = x25519Keypair()
+			curve = cert.Curve_CURVE25519
+		case "P256":
+			pub, rawPriv = p256Keypair()
+			curve = cert.Curve_P256
+		default:
+			return fmt.Errorf("invalid curve: %s", *cf.curve)
+		}
 	}
 
-	err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
-	if err != nil {
-		return fmt.Errorf("error while writing out-key: %s", err)
+	if isP11 {
+		p11Client, err := pkclient.FromUrl(*cf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key: %w", err)
+		}
+	} else {
+		err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
+		if err != nil {
+			return fmt.Errorf("error while writing out-key: %s", err)
+		}
 	}
-
-	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKey(curve, pub), 0600)
+	err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600)
 	if err != nil {
 		return fmt.Errorf("error while writing out-pub: %s", err)
 	}
@@ -72,7 +101,7 @@ func keygenSummary() string {
 
 func keygenHelp(out io.Writer) {
 	cf := newKeygenFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n"))
 	cf.set.SetOutput(out)
 	cf.set.PrintDefaults()
 }

+ 6 - 5
cmd/nebula-cert/keygen_test.go

@@ -9,8 +9,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//TODO: test file permissions
-
 func Test_keygenSummary(t *testing.T) {
 	assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
 }
@@ -26,7 +24,8 @@ func Test_keygenHelp(t *testing.T) {
 			"  -out-key string\n"+
 			"    \tRequired: path to write the private key to\n"+
 			"  -out-pub string\n"+
-			"    \tRequired: path to write the public key to\n",
+			"    \tRequired: path to write the public key to\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n"),
 		ob.String(),
 	)
 }
@@ -80,13 +79,15 @@ func Test_keygen(t *testing.T) {
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
+	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 
 	rb, _ = os.ReadFile(pubF.Name())
-	lPub, b, err := cert.UnmarshalX25519PublicKey(rb)
+	lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lPub, 32)

+ 10 - 3
cmd/nebula-cert/main_test.go

@@ -3,6 +3,7 @@ package main
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"io"
 	"os"
 	"testing"
@@ -10,8 +11,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//TODO: all flag parsing continueOnError will print to stderr on its own currently
-
 func Test_help(t *testing.T) {
 	expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
 		"  Global flags:\n" +
@@ -77,8 +76,16 @@ func assertHelpError(t *testing.T, err error, msg string) {
 	case *helpError:
 		// good
 	default:
-		t.Fatal("err was not a helpError")
+		t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
 	}
 
 	assert.EqualError(t, err, msg)
 }
+
+func optionalPkcs11String(msg string) string {
+	if p11Supported() {
+		return msg
+	} else {
+		return ""
+	}
+}

+ 15 - 0
cmd/nebula-cert/p11_cgo.go

@@ -0,0 +1,15 @@
+//go:build cgo && pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return true
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	return set.String("pkcs11", "", "Optional: PKCS#11 URI to an existing private key")
+}

+ 16 - 0
cmd/nebula-cert/p11_stub.go

@@ -0,0 +1,16 @@
+//go:build !cgo || !pkcs11
+
+package main
+
+import (
+	"flag"
+)
+
+func p11Supported() bool {
+	return false
+}
+
+func p11Flag(set *flag.FlagSet) *string {
+	var ret = ""
+	return &ret
+}

+ 14 - 9
cmd/nebula-cert/print.go

@@ -45,28 +45,27 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		return fmt.Errorf("unable to read cert; %s", err)
 	}
 
-	var c *cert.NebulaCertificate
+	var c cert.Certificate
 	var qrBytes []byte
 	part := 0
 
+	var jsonCerts []cert.Certificate
+
 	for {
-		c, rawCert, err = cert.UnmarshalNebulaCertificateFromPEM(rawCert)
+		c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
 		if err != nil {
 			return fmt.Errorf("error while unmarshaling cert: %s", err)
 		}
 
 		if *pf.json {
-			b, _ := json.Marshal(c)
-			out.Write(b)
-			out.Write([]byte("\n"))
-
+			jsonCerts = append(jsonCerts, c)
 		} else {
-			out.Write([]byte(c.String()))
-			out.Write([]byte("\n"))
+			_, _ = out.Write([]byte(c.String()))
+			_, _ = out.Write([]byte("\n"))
 		}
 
 		if *pf.outQRPath != "" {
-			b, err := c.MarshalToPEM()
+			b, err := c.MarshalPEM()
 			if err != nil {
 				return fmt.Errorf("error while marshalling cert to PEM: %s", err)
 			}
@@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		part++
 	}
 
+	if *pf.json {
+		b, _ := json.Marshal(jsonCerts)
+		_, _ = out.Write(b)
+		_, _ = out.Write([]byte("\n"))
+	}
+
 	if *pf.outQRPath != "" {
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		if err != nil {

+ 144 - 21
cmd/nebula-cert/print_test.go

@@ -2,6 +2,10 @@ package main
 
 import (
 	"bytes"
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
 	"os"
 	"testing"
 	"time"
@@ -68,25 +72,86 @@ func Test_printCert(t *testing.T) {
 	eb.Reset()
 	tf.Truncate(0)
 	tf.Seek(0, 0)
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test",
-			Groups:    []string{"hi"},
-			PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-		},
-		Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-	}
+	ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
+	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
 
-	p, _ := c.MarshalToPEM()
+	p, _ := c.MarshalPEM()
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 
 	err = printCert([]string{"-path", tf.Name()}, ob, eb)
+	fp, _ := c.Fingerprint()
+	pk := hex.EncodeToString(c.PublicKey())
+	sig := hex.EncodeToString(c.Signature())
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: \n\t\tPublic key: 0102030405060708090001020304050607080900010203040506070809000102\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\n\tSignature: 0102030405060708090001020304050607080900010203040506070809000102\n}\n",
+		//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		`{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+`,
 		ob.String(),
 	)
 	assert.Equal(t, "", eb.String())
@@ -96,26 +161,84 @@ func Test_printCert(t *testing.T) {
 	eb.Reset()
 	tf.Truncate(0)
 	tf.Seek(0, 0)
-	c = cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test",
-			Groups:    []string{"hi"},
-			PublicKey: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-		},
-		Signature: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2},
-	}
-
-	p, _ = c.MarshalToPEM()
 	tf.Write(p)
 	tf.Write(p)
 	tf.Write(p)
 
 	err = printCert([]string{"-json", "-path", tf.Name()}, ob, eb)
+	fp, _ = c.Fingerprint()
+	pk = hex.EncodeToString(c.PublicKey())
+	sig = hex.EncodeToString(c.Signature())
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\"0102030405060708090001020304050607080900010203040506070809000102\",\"subnets\":[]},\"fingerprint\":\"cc3492c0e9c48f17547f5987ea807462ebb3451e622590a10bb3763c344c82bd\",\"signature\":\"0102030405060708090001020304050607080900010203040506070809000102\"}\n",
+		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
+`,
 		ob.String(),
 	)
 	assert.Equal(t, "", eb.String())
 }
+
+// NewTestCaCert will generate a CA cert
+func NewTestCaCert(name string, pubKey, privKey []byte, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
+	var err error
+	if pubKey == nil || privKey == nil {
+		pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	t := &cert.TBSCertificate{
+		Version:        cert.Version1,
+		Name:           name,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pubKey,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, cert.Curve_CURVE25519, privKey)
+	if err != nil {
+		panic(err)
+	}
+
+	return c, privKey
+}
+
+func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte) {
+	if before.IsZero() {
+		before = ca.NotBefore()
+	}
+
+	if after.IsZero() {
+		after = ca.NotAfter()
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	pub, rawPriv := x25519Keypair()
+	nc := &cert.TBSCertificate{
+		Version:        cert.Version1,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), signerKey)
+	if err != nil {
+		panic(err)
+	}
+
+	return c, rawPriv
+}

+ 238 - 110
cmd/nebula-cert/sign.go

@@ -3,50 +3,63 @@ package main
 import (
 	"crypto/ecdh"
 	"crypto/rand"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"strings"
 	"time"
 
 	"github.com/skip2/go-qrcode"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 )
 
 type signFlags struct {
-	set         *flag.FlagSet
-	caKeyPath   *string
-	caCertPath  *string
-	name        *string
-	ip          *string
-	duration    *time.Duration
-	inPubPath   *string
-	outKeyPath  *string
-	outCertPath *string
-	outQRPath   *string
-	groups      *string
-	subnets     *string
+	set            *flag.FlagSet
+	version        *uint
+	caKeyPath      *string
+	caCertPath     *string
+	name           *string
+	networks       *string
+	unsafeNetworks *string
+	duration       *time.Duration
+	inPubPath      *string
+	outKeyPath     *string
+	outCertPath    *string
+	outQRPath      *string
+	groups         *string
+
+	p11url *string
+
+	// Deprecated options
+	ip      *string
+	subnets *string
 }
 
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
-	sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
+	sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
+	sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
-	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
-	return &sf
+	sf.p11url = p11Flag(sf.set)
 
+	sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
+	sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
+	return &sf
 }
 
 func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error {
@@ -56,8 +69,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return err
 	}
 
-	if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
-		return err
+	isP11 := len(*sf.p11url) > 0
+
+	if !isP11 {
+		if err := mustFlagString("ca-key", sf.caKeyPath); err != nil {
+			return err
+		}
 	}
 	if err := mustFlagString("ca-crt", sf.caCertPath); err != nil {
 		return err
@@ -65,50 +82,67 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	if err := mustFlagString("name", sf.name); err != nil {
 		return err
 	}
-	if err := mustFlagString("ip", sf.ip); err != nil {
-		return err
-	}
-	if *sf.inPubPath != "" && *sf.outKeyPath != "" {
+	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 
-	rawCAKey, err := os.ReadFile(*sf.caKeyPath)
-	if err != nil {
-		return fmt.Errorf("error while reading ca-key: %s", err)
+	var v4Networks []netip.Prefix
+	var v6Networks []netip.Prefix
+	if *sf.networks == "" && *sf.ip != "" {
+		// Pull up deprecated -ip flag if needed
+		*sf.networks = *sf.ip
+	}
+
+	if len(*sf.networks) == 0 {
+		return newHelpErrorf("-networks is required")
+	}
+
+	version := cert.Version(*sf.version)
+	if version != 0 && version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
 	}
 
 	var curve cert.Curve
 	var caKey []byte
 
-	// naively attempt to decode the private key as though it is not encrypted
-	caKey, _, curve, err = cert.UnmarshalSigningPrivateKey(rawCAKey)
-	if err == cert.ErrPrivateKeyEncrypted {
-		// ask for a passphrase until we get one
-		var passphrase []byte
-		for i := 0; i < 5; i++ {
-			out.Write([]byte("Enter passphrase: "))
-			passphrase, err = pr.ReadPassword()
-
-			if err == ErrNoTerminal {
-				return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
-			} else if err != nil {
-				return fmt.Errorf("error reading password: %s", err)
-			}
+	if !isP11 {
+		var rawCAKey []byte
+		rawCAKey, err := os.ReadFile(*sf.caKeyPath)
 
-			if len(passphrase) > 0 {
-				break
-			}
-		}
-		if len(passphrase) == 0 {
-			return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+		if err != nil {
+			return fmt.Errorf("error while reading ca-key: %s", err)
 		}
 
-		curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
-		if err != nil {
-			return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+		// naively attempt to decode the private key as though it is not encrypted
+		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
+		if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
+			// ask for a passphrase until we get one
+			var passphrase []byte
+			for i := 0; i < 5; i++ {
+				out.Write([]byte("Enter passphrase: "))
+				passphrase, err = pr.ReadPassword()
+
+				if errors.Is(err, ErrNoTerminal) {
+					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
+				} else if err != nil {
+					return fmt.Errorf("error reading password: %s", err)
+				}
+
+				if len(passphrase) > 0 {
+					break
+				}
+			}
+			if len(passphrase) == 0 {
+				return fmt.Errorf("cannot open encrypted ca-key without passphrase")
+			}
+
+			curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey)
+			if err != nil {
+				return fmt.Errorf("error while parsing encrypted ca-key: %s", err)
+			}
+		} else if err != nil {
+			return fmt.Errorf("error while parsing ca-key: %s", err)
 		}
-	} else if err != nil {
-		return fmt.Errorf("error while parsing ca-key: %s", err)
 	}
 
 	rawCACert, err := os.ReadFile(*sf.caCertPath)
@@ -116,18 +150,15 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("error while reading ca-crt: %s", err)
 	}
 
-	caCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCACert)
+	caCert, _, err := cert.UnmarshalCertificateFromPEM(rawCACert)
 	if err != nil {
 		return fmt.Errorf("error while parsing ca-crt: %s", err)
 	}
 
-	if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate does not match private key")
-	}
-
-	issuer, err := caCert.Sha256Sum()
-	if err != nil {
-		return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err)
+	if !isP11 {
+		if err := caCert.VerifyPrivateKey(curve, caKey); err != nil {
+			return fmt.Errorf("refusing to sign, root certificate does not match private key")
+		}
 	}
 
 	if caCert.Expired(time.Now()) {
@@ -136,82 +167,99 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 	// if no duration is given, expire one second before the root expires
 	if *sf.duration <= 0 {
-		*sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1
+		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
 	}
 
-	ip, ipNet, err := net.ParseCIDR(*sf.ip)
-	if err != nil {
-		return newHelpErrorf("invalid ip definition: %s", err)
-	}
-	if ip.To4() == nil {
-		return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
-	}
-	ipNet.IP = ip
+	if *sf.networks != "" {
+		for _, rs := range strings.Split(*sf.networks, ",") {
+			rs := strings.Trim(rs, " ")
+			if rs != "" {
+				n, err := netip.ParsePrefix(rs)
+				if err != nil {
+					return newHelpErrorf("invalid -networks definition: %s", rs)
+				}
 
-	groups := []string{}
-	if *sf.groups != "" {
-		for _, rg := range strings.Split(*sf.groups, ",") {
-			g := strings.TrimSpace(rg)
-			if g != "" {
-				groups = append(groups, g)
+				if n.Addr().Is4() {
+					v4Networks = append(v4Networks, n)
+				} else {
+					v6Networks = append(v6Networks, n)
+				}
 			}
 		}
 	}
 
-	subnets := []*net.IPNet{}
-	if *sf.subnets != "" {
-		for _, rs := range strings.Split(*sf.subnets, ",") {
+	var v4UnsafeNetworks []netip.Prefix
+	var v6UnsafeNetworks []netip.Prefix
+	if *sf.unsafeNetworks == "" && *sf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*sf.unsafeNetworks = *sf.subnets
+	}
+
+	if *sf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				_, s, err := net.ParseCIDR(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
-				if s.IP.To4() == nil {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+
+				if n.Addr().Is4() {
+					v4UnsafeNetworks = append(v4UnsafeNetworks, n)
+				} else {
+					v6UnsafeNetworks = append(v6UnsafeNetworks, n)
 				}
-				subnets = append(subnets, s)
+			}
+		}
+	}
+
+	var groups []string
+	if *sf.groups != "" {
+		for _, rg := range strings.Split(*sf.groups, ",") {
+			g := strings.TrimSpace(rg)
+			if g != "" {
+				groups = append(groups, g)
 			}
 		}
 	}
 
 	var pub, rawPriv []byte
+	var p11Client *pkclient.PKClient
+
+	if isP11 {
+		curve = cert.Curve_P256
+		p11Client, err = pkclient.FromUrl(*sf.p11url)
+		if err != nil {
+			return fmt.Errorf("error while creating PKCS#11 client: %w", err)
+		}
+		defer func(client *pkclient.PKClient) {
+			_ = client.Close()
+		}(p11Client)
+	}
+
 	if *sf.inPubPath != "" {
+		var pubCurve cert.Curve
 		rawPub, err := os.ReadFile(*sf.inPubPath)
 		if err != nil {
 			return fmt.Errorf("error while reading in-pub: %s", err)
 		}
-		var pubCurve cert.Curve
-		pub, _, pubCurve, err = cert.UnmarshalPublicKey(rawPub)
+
+		pub, _, pubCurve, err = cert.UnmarshalPublicKeyFromPEM(rawPub)
 		if err != nil {
 			return fmt.Errorf("error while parsing in-pub: %s", err)
 		}
 		if pubCurve != curve {
 			return fmt.Errorf("curve of in-pub does not match ca")
 		}
+	} else if isP11 {
+		pub, err = p11Client.GetPubKey()
+		if err != nil {
+			return fmt.Errorf("error while getting public key with PKCS#11: %w", err)
+		}
 	} else {
 		pub, rawPriv = newKeypair(curve)
 	}
 
-	nc := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      *sf.name,
-			Ips:       []*net.IPNet{ipNet},
-			Groups:    groups,
-			Subnets:   subnets,
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(*sf.duration),
-			PublicKey: pub,
-			IsCA:      false,
-			Issuer:    issuer,
-			Curve:     curve,
-		},
-	}
-
-	if err := nc.CheckRootConstrains(caCert); err != nil {
-		return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err)
-	}
-
 	if *sf.outKeyPath == "" {
 		*sf.outKeyPath = *sf.name + ".key"
 	}
@@ -224,25 +272,105 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 
-	err = nc.Sign(curve, caKey)
-	if err != nil {
-		return fmt.Errorf("error while signing: %s", err)
+	var crts []cert.Certificate
+
+	notBefore := time.Now()
+	notAfter := notBefore.Add(*sf.duration)
+
+	if version == 0 || version == cert.Version1 {
+		// Make sure we at least have an ip
+		if len(v4Networks) != 1 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
+		}
+
+		if version == cert.Version1 {
+			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
+			if len(v6Networks) > 0 {
+				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
+			}
+
+			if len(v6UnsafeNetworks) > 0 {
+				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+			}
+		}
+
+		t := &cert.TBSCertificate{
+			Version:        cert.Version1,
+			Name:           *sf.name,
+			Networks:       []netip.Prefix{v4Networks[0]},
+			Groups:         groups,
+			UnsafeNetworks: v4UnsafeNetworks,
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
+	}
+
+	if version == 0 || version == cert.Version2 {
+		t := &cert.TBSCertificate{
+			Version:        cert.Version2,
+			Name:           *sf.name,
+			Networks:       append(v4Networks, v6Networks...),
+			Groups:         groups,
+			UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
 	}
 
-	if *sf.inPubPath == "" {
+	if !isP11 && *sf.inPubPath == "" {
 		if _, err := os.Stat(*sf.outKeyPath); err == nil {
 			return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath)
 		}
 
-		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKey(curve, rawPriv), 0600)
+		err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600)
 		if err != nil {
 			return fmt.Errorf("error while writing out-key: %s", err)
 		}
 	}
 
-	b, err := nc.MarshalToPEM()
-	if err != nil {
-		return fmt.Errorf("error while marshalling certificate: %s", err)
+	var b []byte
+	for _, c := range crts {
+		sb, err := c.MarshalPEM()
+		if err != nil {
+			return fmt.Errorf("error while marshalling certificate: %s", err)
+		}
+		b = append(b, sb...)
 	}
 
 	err = os.WriteFile(*sf.outCertPath, b, 0600)

+ 74 - 75
cmd/nebula-cert/sign_test.go

@@ -16,8 +16,6 @@ import (
 	"golang.org/x/crypto/ed25519"
 )
 
-//TODO: test file permissions
-
 func Test_signSummary(t *testing.T) {
 	assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
 }
@@ -39,17 +37,24 @@ func Test_signHelp(t *testing.T) {
 			"  -in-pub string\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"  -ip string\n"+
-			"    \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
+			"    \tDeprecated, see -networks\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
+			"  -networks string\n"+
+			"    \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"  -out-key string\n"+
 			"    \tOptional (if in-pub not set): path to write the private key to\n"+
 			"  -out-qr string\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
+			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
 		ob.String(),
 	)
 }
@@ -76,20 +81,20 @@ func Test_signCert(t *testing.T) {
 
 	// required args
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
 	), "-name is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
-	), "-ip is required")
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-networks is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// cannot set -in-pub and -out-key
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
 	), "cannot set both -in-pub and -out-key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -97,7 +102,7 @@ func Test_signCert(t *testing.T) {
 	// failed to read key
 	ob.Reset()
 	eb.Reset()
-	args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 	// failed to unmarshal key
@@ -107,7 +112,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -116,10 +121,10 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
-	caKeyF.Write(cert.MarshalEd25519PrivateKey(caPriv))
+	caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
 
 	// failed to read cert
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -131,26 +136,18 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// write a proper ca cert for later
-	ca := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(time.Minute * 200),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	b, _ := ca.MarshalToPEM()
+	ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
+	b, _ := ca.MarshalPEM()
 	caCrtF.Write(b)
 
 	// failed to read pub
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -162,7 +159,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(inPubF.Name())
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -171,35 +168,42 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	inPub, _ := x25519Keypair()
-	inPubF.Write(cert.MarshalX25519PublicKey(inPub))
+	inPubF.Write(cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub))
 
 	// bad ip cidr
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: invalid CIDR address: a1.1.1.1/24")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// bad subnet cidr
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: invalid CIDR address: a")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -208,11 +212,11 @@ func Test_signCert(t *testing.T) {
 	caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF2.Name())
-	caKeyF2.Write(cert.MarshalEd25519PrivateKey(caPriv2))
+	caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -220,7 +224,7 @@ func Test_signCert(t *testing.T) {
 	// failed key write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -233,7 +237,7 @@ func Test_signCert(t *testing.T) {
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -247,40 +251,41 @@ func Test_signCert(t *testing.T) {
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// read cert and key files
 	rb, _ := os.ReadFile(keyF.Name())
-	lKey, b, err := cert.UnmarshalX25519PrivateKey(rb)
+	lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
+	assert.Equal(t, cert.Curve_CURVE25519, curve)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 	assert.Len(t, lKey, 32)
 
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err := cert.UnmarshalNebulaCertificateFromPEM(rb)
+	lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
 
-	assert.Equal(t, "test", lCrt.Details.Name)
-	assert.Equal(t, "1.1.1.1/24", lCrt.Details.Ips[0].String())
-	assert.Len(t, lCrt.Details.Ips, 1)
-	assert.False(t, lCrt.Details.IsCA)
-	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Details.Groups)
-	assert.Len(t, lCrt.Details.Subnets, 3)
-	assert.Len(t, lCrt.Details.PublicKey, 32)
-	assert.Equal(t, time.Duration(time.Minute*100), lCrt.Details.NotAfter.Sub(lCrt.Details.NotBefore))
+	assert.Equal(t, "test", lCrt.Name())
+	assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
+	assert.Len(t, lCrt.Networks(), 1)
+	assert.False(t, lCrt.IsCA())
+	assert.Equal(t, []string{"1", "2", "3", "4", "5"}, lCrt.Groups())
+	assert.Len(t, lCrt.UnsafeNetworks(), 3)
+	assert.Len(t, lCrt.PublicKey(), 32)
+	assert.Equal(t, time.Duration(time.Minute*100), lCrt.NotAfter().Sub(lCrt.NotBefore()))
 
 	sns := []string{}
-	for _, sn := range lCrt.Details.Subnets {
+	for _, sn := range lCrt.UnsafeNetworks() {
 		sns = append(sns, sn.String())
 	}
 	assert.Equal(t, []string{"10.1.1.1/32", "10.2.2.2/32", "10.5.5.5/32"}, sns)
 
-	issuer, _ := ca.Sha256Sum()
-	assert.Equal(t, issuer, lCrt.Details.Issuer)
+	issuer, _ := ca.Fingerprint()
+	assert.Equal(t, issuer, lCrt.Issuer())
 
 	assert.True(t, lCrt.CheckSignature(caPub))
 
@@ -289,37 +294,39 @@ func Test_signCert(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// read cert file and check pub key matches in-pub
 	rb, _ = os.ReadFile(crtF.Name())
-	lCrt, b, err = cert.UnmarshalNebulaCertificateFromPEM(rb)
+	lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
 	assert.Len(t, b, 0)
 	assert.Nil(t, err)
-	assert.Equal(t, lCrt.Details.PublicKey, inPub)
+	assert.Equal(t, lCrt.PublicKey(), inPub)
 
 	// test refuse to sign cert with duration beyond root
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
-	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate constraints violated: certificate expires after signing certificate")
+	os.Remove(keyF.Name())
+	os.Remove(crtF.Name())
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -327,14 +334,14 @@ func Test_signCert(t *testing.T) {
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -361,20 +368,12 @@ func Test_signCert(t *testing.T) {
 	b, _ = cert.EncryptAndMarshalSigningPrivateKey(cert.Curve_CURVE25519, caPriv, passphrase, kdfParams)
 	caKeyF.Write(b)
 
-	ca = cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: time.Now(),
-			NotAfter:  time.Now().Add(time.Minute * 200),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	b, _ = ca.MarshalToPEM()
+	ca, _ = NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil)
+	b, _ = ca.MarshalPEM()
 	caCrtF.Write(b)
 
 	// test with the proper password
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
@@ -384,7 +383,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	testpw.password = []byte("invalid password")
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
@@ -393,7 +392,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, nopw))
 	// normally the user hitting enter on the prompt would add newlines between these
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@@ -403,7 +402,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, errpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())

+ 26 - 15
cmd/nebula-cert/verify.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
-		return fmt.Errorf("error while reading ca: %s", err)
+		return fmt.Errorf("error while reading ca: %w", err)
 	}
 
 	caPool := cert.NewCAPool()
 	for {
-		rawCACert, err = caPool.AddCACertificate(rawCACert)
+		rawCACert, err = caPool.AddCAFromPEM(rawCACert)
 		if err != nil {
-			return fmt.Errorf("error while adding ca cert to pool: %s", err)
+			return fmt.Errorf("error while adding ca cert to pool: %w", err)
 		}
 
 		if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
-		return fmt.Errorf("unable to read crt; %s", err)
+		return fmt.Errorf("unable to read crt: %w", err)
 	}
-
-	c, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
-	if err != nil {
-		return fmt.Errorf("error while parsing crt: %s", err)
-	}
-
-	good, err := c.Verify(time.Now(), caPool)
-	if !good {
-		return err
+	var errs []error
+	for {
+		if len(rawCert) == 0 {
+			break
+		}
+		c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
+		if err != nil {
+			return fmt.Errorf("error while parsing crt: %w", err)
+		}
+		rawCert = extra
+		_, err = caPool.VerifyCertificate(time.Now(), c)
+		if err != nil {
+			switch {
+			case errors.Is(err, cert.ErrCaNotFound):
+				errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
+			default:
+				errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
+			}
+		}
 	}
 
-	return nil
+	return errors.Join(errs...)
 }
 
 func verifySummary() string {
@@ -80,7 +91,7 @@ func verifySummary() string {
 
 func verifyHelp(out io.Writer) {
 	vf := newVerifyFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
 	vf.set.SetOutput(out)
 	vf.set.PrintDefaults()
 }

+ 13 - 30
cmd/nebula-cert/verify_test.go

@@ -3,6 +3,7 @@ package main
 import (
 	"bytes"
 	"crypto/rand"
+	"errors"
 	"os"
 	"testing"
 	"time"
@@ -67,17 +68,8 @@ func Test_verify(t *testing.T) {
 
 	// make a ca for later
 	caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
-	ca := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test-ca",
-			NotBefore: time.Now().Add(time.Hour * -1),
-			NotAfter:  time.Now().Add(time.Hour * 2),
-			PublicKey: caPub,
-			IsCA:      true,
-		},
-	}
-	ca.Sign(cert.Curve_CURVE25519, caPriv)
-	b, _ := ca.MarshalToPEM()
+	ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil)
+	b, _ := ca.MarshalPEM()
 	caFile.Truncate(0)
 	caFile.Seek(0, 0)
 	caFile.Write(b)
@@ -86,7 +78,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
+	assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 	// invalid crt at path
 	ob.Reset()
@@ -102,22 +94,13 @@ func Test_verify(t *testing.T) {
 	assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
 
 	// unverifiable cert at path
-	_, badPriv, _ := ed25519.GenerateKey(rand.Reader)
-	certPub, _ := x25519Keypair()
-	signer, _ := ca.Sha256Sum()
-	crt := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "test-cert",
-			NotBefore: time.Now().Add(time.Hour * -1),
-			NotAfter:  time.Now().Add(time.Hour),
-			PublicKey: certPub,
-			IsCA:      false,
-			Issuer:    signer,
-		},
+	crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
+	// Slightly evil hack to modify the certificate after it was sealed to generate an invalid signature
+	pub := crt.PublicKey()
+	for i, _ := range pub {
+		pub[i] = 0
 	}
-
-	crt.Sign(cert.Curve_CURVE25519, badPriv)
-	b, _ = crt.MarshalToPEM()
+	b, _ = crt.MarshalPEM()
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
 	certFile.Write(b)
@@ -125,11 +108,11 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "certificate signature did not match")
+	assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
 
 	// verified cert at path
-	crt.Sign(cert.Curve_CURVE25519, caPriv)
-	b, _ = crt.MarshalToPEM()
+	crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
+	b, _ = crt.MarshalPEM()
 	certFile.Truncate(0)
 	certFile.Seek(0, 0)
 	certFile.Write(b)

+ 0 - 3
config/config_test.go

@@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) {
 		"new": "hi",
 	}
 	assert.Equal(t, expected, c.Settings)
-
-	//TODO: test symlinked file
-	//TODO: test symlinked directory
 }
 
 func TestConfig_Get(t *testing.T) {

+ 61 - 36
connection_manager.go

@@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 	case deleteTunnel:
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 
 	case closeTunnel:
@@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 	for _, r := range relayFor {
-		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
 
 		var index uint32
 		var relayFrom netip.Addr
@@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = existing.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = existing.PeerAddr
 			case ForwardingType:
-				relayFrom = existing.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = existing.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
 			}
@@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			n.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = r.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = r.PeerAddr
 			case ForwardingType:
-				relayFrom = r.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = r.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
 			}
 		}
 
-		//TODO: IPV6-WORK
-		relayFromB := relayFrom.As4()
-		relayToB := relayTo.As4()
-
 		// Send a CreateRelayRequest to the peer.
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         binary.BigEndian.Uint32(relayFromB[:]),
-			RelayToIp:           binary.BigEndian.Uint32(relayToB[:]),
 		}
+
+		switch newhostinfo.GetCert().Certificate.Version() {
+		case cert.Version1:
+			if !relayFrom.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				continue
+			}
+
+			if !relayTo.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				continue
+			}
+
+			b := relayFrom.As4()
+			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = relayTo.As4()
+			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		case cert.Version2:
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
+			req.RelayToAddr = netAddrToProtoAddr(relayTo)
+		default:
+			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			continue
+		}
+
 		msg, err := req.Marshal()
 		if err != nil {
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           req.RelayFromIp,
-				"relayTo":             req.RelayToIp,
+				"relayFrom":           req.RelayFromAddr,
+				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
-				"vpnIp":               newhostinfo.vpnIp}).
+				"vpnAddrs":            newhostinfo.vpnAddrs}).
 				Info("send CreateRelayRequest")
 		}
 	}
@@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		return closeTunnel, hostinfo, nil
 	}
 
-	primary := n.hostMap.Hosts[hostinfo.vpnIp]
+	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
@@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 
-	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
-		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
-		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
-		// The remotes vpn ip is lower than mine. I will not flip.
+	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
+	// vpn addr is static across all tunnels for this host pair so lets
+	// use that to determine if we should consider swapping.
+	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
+		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}
 
-	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
+	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
+	// settle down.
+	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 	n.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnIp] == primary {
+	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
 		n.hostMap.unlockedMakePrimary(current)
 	}
 	n.hostMap.Unlock()
@@ -436,8 +458,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
-	if valid {
+	caPool := n.intf.pki.GetCAPool()
+	err := caPool.VerifyCachedCertificate(now, remoteCert)
+	if err == nil {
 		return false
 	}
 
@@ -446,9 +469,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	fingerprint, _ := remoteCert.Sha256Sum()
 	hostinfo.logger(n.l).WithError(err).
-		WithField("fingerprint", fingerprint).
+		WithField("fingerprint", remoteCert.Fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
 	return true
@@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
+	cs := n.intf.pki.getCertState()
+	curCrt := hostinfo.ConnectionState.myCert
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
 
-	n.l.WithField("vpnIp", hostinfo.vpnIp).
+	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }

+ 162 - 61
connection_manager_test.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"crypto/ed25519"
 	"crypto/rand"
-	"net"
 	"net/netip"
 	"testing"
 	"time"
@@ -35,20 +34,19 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -75,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &cert.NebulaCertificate{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -89,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 
@@ -106,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -158,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &cert.NebulaCertificate{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -171,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@@ -188,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
@@ -197,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
 
 // Check if we can disconnect the peer.
@@ -206,55 +203,48 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	now := time.Now()
 	l := test.NewLogger()
-	ipNet := net.IPNet{
-		IP:   net.IPv4(172, 1, 1, 2),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
+
 	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	// Generate keys for CA and peer's cert.
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
-	caCert := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "ca",
-			NotBefore: now,
-			NotAfter:  now.Add(1 * time.Hour),
-			IsCA:      true,
-			PublicKey: pubCA,
-		},
+	tbs := &cert.TBSCertificate{
+		Version:   1,
+		Name:      "ca",
+		IsCA:      true,
+		NotBefore: now,
+		NotAfter:  now.Add(1 * time.Hour),
+		PublicKey: pubCA,
 	}
 
-	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
-	ncp := &cert.NebulaCAPool{
-		CAs: cert.NewCAPool().CAs,
-	}
-	ncp.CAs["ca"] = &caCert
+	caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, err)
+	ncp := cert.NewCAPool()
+	assert.NoError(t, ncp.AddCA(caCert))
 
 	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
-	peerCert := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:      "host",
-			Ips:       []*net.IPNet{&ipNet},
-			Subnets:   []*net.IPNet{},
-			NotBefore: now,
-			NotAfter:  now.Add(60 * time.Second),
-			PublicKey: pubCrt,
-			IsCA:      false,
-			Issuer:    "ca",
-		},
+	tbs = &cert.TBSCertificate{
+		Version:   1,
+		Name:      "host",
+		Networks:  []netip.Prefix{vpncidr},
+		NotBefore: now,
+		NotAfter:  now.Add(60 * time.Second),
+		PublicKey: pubCrt,
 	}
-	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
+	peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, err)
+
+	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -280,10 +270,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.connectionManager = nc
 
 	hostinfo := &HostInfo{
-		vpnIp: vpnIp,
+		vpnAddrs: []netip.Addr{vpnIp},
 		ConnectionState: &ConnectionState{
-			myCert:   &cert.NebulaCertificate{},
-			peerCert: &peerCert,
+			myCert:   &dummyCert{},
+			peerCert: cachedPeerCert,
 			H:        &noise.HandshakeState{},
 		},
 	}
@@ -303,3 +293,114 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
 	assert.True(t, invalid)
 }
+
+type dummyCert struct {
+	version        cert.Version
+	curve          cert.Curve
+	groups         []string
+	isCa           bool
+	issuer         string
+	name           string
+	networks       []netip.Prefix
+	notAfter       time.Time
+	notBefore      time.Time
+	publicKey      []byte
+	signature      []byte
+	unsafeNetworks []netip.Prefix
+}
+
+func (d *dummyCert) Version() cert.Version {
+	return d.version
+}
+
+func (d *dummyCert) Curve() cert.Curve {
+	return d.curve
+}
+
+func (d *dummyCert) Groups() []string {
+	return d.groups
+}
+
+func (d *dummyCert) IsCA() bool {
+	return d.isCa
+}
+
+func (d *dummyCert) Issuer() string {
+	return d.issuer
+}
+
+func (d *dummyCert) Name() string {
+	return d.name
+}
+
+func (d *dummyCert) Networks() []netip.Prefix {
+	return d.networks
+}
+
+func (d *dummyCert) NotAfter() time.Time {
+	return d.notAfter
+}
+
+func (d *dummyCert) NotBefore() time.Time {
+	return d.notBefore
+}
+
+func (d *dummyCert) PublicKey() []byte {
+	return d.publicKey
+}
+
+func (d *dummyCert) Signature() []byte {
+	return d.signature
+}
+
+func (d *dummyCert) UnsafeNetworks() []netip.Prefix {
+	return d.unsafeNetworks
+}
+
+func (d *dummyCert) MarshalForHandshakes() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Sign(curve cert.Curve, key []byte) error {
+	return nil
+}
+
+func (d *dummyCert) CheckSignature(key []byte) bool {
+	return true
+}
+
+func (d *dummyCert) Expired(t time.Time) bool {
+	return false
+}
+
+func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error {
+	return nil
+}
+
+func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error {
+	return nil
+}
+
+func (d *dummyCert) String() string {
+	return ""
+}
+
+func (d *dummyCert) Marshal() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) MarshalPEM() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Fingerprint() (string, error) {
+	return "", nil
+}
+
+func (d *dummyCert) MarshalJSON() ([]byte, error) {
+	return nil, nil
+}
+
+func (d *dummyCert) Copy() cert.Certificate {
+	return d
+}

+ 32 - 23
connection_state.go

@@ -3,6 +3,7 @@ package nebula
 import (
 	"crypto/rand"
 	"encoding/json"
+	"fmt"
 	"sync"
 	"sync/atomic"
 
@@ -18,50 +19,54 @@ type ConnectionState struct {
 	eKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	H              *noise.HandshakeState
-	myCert         *cert.NebulaCertificate
-	peerCert       *cert.NebulaCertificate
+	myCert         cert.Certificate
+	peerCert       *cert.CachedCertificate
 	initiator      bool
 	messageCounter atomic.Uint64
 	window         *Bits
 	writeLock      sync.Mutex
 }
 
-func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
 	var dhFunc noise.DHFunc
-	switch certState.Certificate.Details.Curve {
+	switch crt.Curve() {
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
-		dhFunc = noiseutil.DHP256
+		if cs.pkcs11Backed {
+			dhFunc = noiseutil.DHP256PKCS11
+		} else {
+			dhFunc = noiseutil.DHP256
+		}
 	default:
-		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
-		return nil
+		return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
 	}
 
-	var cs noise.CipherSuite
-	if cipher == "chachapoly" {
-		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	var ncs noise.CipherSuite
+	if cs.cipher == "chachapoly" {
+		ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	} else {
-		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
+		ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 
-	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
+	static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
 
 	b := NewBits(ReplayWindow)
-	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
+	// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
 	b.Update(l, 0)
 
 	hs, err := noise.NewHandshakeState(noise.Config{
-		CipherSuite:           cs,
-		Random:                rand.Reader,
-		Pattern:               pattern,
-		Initiator:             initiator,
-		StaticKeypair:         static,
-		PresharedKey:          psk,
-		PresharedKeyPlacement: pskStage,
+		CipherSuite:   ncs,
+		Random:        rand.Reader,
+		Pattern:       pattern,
+		Initiator:     initiator,
+		StaticKeypair: static,
+		//NOTE: These should come from CertState (pki.go) when we finally implement it
+		PresharedKey:          []byte{},
+		PresharedKeyPlacement: 0,
 	})
 	if err != nil {
-		return nil
+		return nil, fmt.Errorf("NewConnectionState: %s", err)
 	}
 
 	// The queue and ready params prevent a counter race that would happen when
@@ -70,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		initiator: initiator,
 		window:    b,
-		myCert:    certState.Certificate,
+		myCert:    crt,
 	}
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	ci.messageCounter.Add(2)
 
-	return ci
+	return ci, nil
 }
 
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@@ -85,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"message_counter": cs.messageCounter.Load(),
 	})
 }
+
+func (cs *ConnectionState) Curve() cert.Curve {
+	return cs.myCert.Curve()
+}

+ 37 - 36
control.go

@@ -19,9 +19,9 @@ import (
 type controlEach func(h *HostInfo)
 
 type controlHostLister interface {
-	QueryVpnIp(vpnIp netip.Addr) *HostInfo
+	QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
-	ForEachVpnIp(each controlEach)
+	ForEachVpnAddr(each controlEach)
 	GetPreferredRanges() []netip.Prefix
 }
 
@@ -37,15 +37,15 @@ type Control struct {
 }
 
 type ControlHostInfo struct {
-	VpnIp                  netip.Addr              `json:"vpnIp"`
-	LocalIndex             uint32                  `json:"localIndex"`
-	RemoteIndex            uint32                  `json:"remoteIndex"`
-	RemoteAddrs            []netip.AddrPort        `json:"remoteAddrs"`
-	Cert                   *cert.NebulaCertificate `json:"cert"`
-	MessageCounter         uint64                  `json:"messageCounter"`
-	CurrentRemote          netip.AddrPort          `json:"currentRemote"`
-	CurrentRelaysToMe      []netip.Addr            `json:"currentRelaysToMe"`
-	CurrentRelaysThroughMe []netip.Addr            `json:"currentRelaysThroughMe"`
+	VpnAddrs               []netip.Addr     `json:"vpnAddrs"`
+	LocalIndex             uint32           `json:"localIndex"`
+	RemoteIndex            uint32           `json:"remoteIndex"`
+	RemoteAddrs            []netip.AddrPort `json:"remoteAddrs"`
+	Cert                   cert.Certificate `json:"cert"`
+	MessageCounter         uint64           `json:"messageCounter"`
+	CurrentRemote          netip.AddrPort   `json:"currentRemote"`
+	CurrentRelaysToMe      []netip.Addr     `json:"currentRelaysToMe"`
+	CurrentRelaysThroughMe []netip.Addr     `json:"currentRelaysThroughMe"`
 }
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -130,15 +130,18 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 }
 
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
-func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate {
-	if c.f.myVpnNet.Addr() == vpnIp {
-		return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
+	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
+	if found {
+		// Only returning the default certificate since its impossible
+		// for any other host but ourselves to have more than 1
+		return c.f.pki.getCertState().GetDefaultCertificate().Copy()
 	}
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 		return nil
 	}
-	return hi.GetCert()
+	return hi.GetCert().Certificate.Copy()
 }
 
 // CreateTunnel creates a new tunnel to the given vpn ip.
@@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
 
 // PrintTunnel creates a new tunnel to the given vpn ip.
 func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 		return nil
 	}
@@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
 	return hi.CopyCache()
 }
 
-// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
+// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
-func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
+func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	if pending {
 		hl = c.f.handshakeManager
@@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 		hl = c.f.hostMap
 	}
 
-	h := hl.QueryVpnIp(vpnIp)
+	h := hl.QueryVpnAddr(vpnAddr)
 	if h == nil {
 		return nil
 	}
@@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 		return nil
 	}
@@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 		return false
 	}
@@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
 // the int returned is a count of tunnels closed
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
-	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
-	lighthouses := c.f.lightHouse.GetLighthouses()
-
 	shutdown := func(h *HostInfo) {
-		if excludeLighthouses {
-			if _, ok := lighthouses[h.vpnIp]; ok {
-				return
-			}
+		if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
+			return
 		}
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.closeTunnel(h)
 
-		c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
+		c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
 			Debug("Sending close tunnel message")
 		closed++
 	}
@@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
-		relayingHosts[relayingHost.vpnIp] = relayingHost
+		relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
 	}
 	c.f.hostMap.Unlock()
 
@@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Hosts map
 	c.f.hostMap.Lock()
 	for _, relayHost := range c.f.hostMap.Indexes {
-		if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
+		if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
 			hostInfos = append(hostInfos, relayHost)
 		}
 	}
@@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device {
 }
 
 func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
-
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp,
+		VpnAddrs:               make([]netip.Addr, len(h.vpnAddrs)),
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
@@ -285,12 +282,16 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 		CurrentRemote:          h.remote,
 	}
 
+	for i, a := range h.vpnAddrs {
+		chi.VpnAddrs[i] = a
+	}
+
 	if h.ConnectionState != nil {
 		chi.MessageCounter = h.ConnectionState.messageCounter.Load()
 	}
 
 	if c := h.GetCert(); c != nil {
-		chi.Cert = c.Copy()
+		chi.Cert = c.Certificate.Copy()
 	}
 
 	return chi
@@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 	hosts := make([]ControlHostInfo, 0)
 	pr := hl.GetPreferredRanges()
-	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+	hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 	})
 	return hosts

+ 22 - 36
control_test.go

@@ -5,7 +5,6 @@ import (
 	"net/netip"
 	"reflect"
 	"testing"
-	"time"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
@@ -14,10 +13,13 @@ import (
 )
 
 func TestControl_GetHostInfoByVpnIp(t *testing.T) {
+	//TODO: CERT-V2 with multiple certificate versions we have a problem with this test
+	// Some certs versions have different characteristics and each version implements their own Copy() func
+	// which means this is not a good place to test for exposing memory
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, netip.Prefix{})
+	hm := newHostMap(l)
 	hm.preferredRanges.Store(&[]netip.Prefix{})
 
 	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@@ -33,42 +35,27 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 
-	crt := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test",
-			Ips:            []*net.IPNet{&ipNet},
-			Subnets:        []*net.IPNet{},
-			Groups:         []string{"default-group"},
-			NotBefore:      time.Unix(1, 0),
-			NotAfter:       time.Unix(2, 0),
-			PublicKey:      []byte{5, 6, 7, 8},
-			IsCA:           false,
-			Issuer:         "the-issuer",
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-		},
-		Signature: []byte{1, 2, 1, 2, 1, 3},
-	}
-
-	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
-	remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+	remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
 
 	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
 	assert.True(t, ok)
 
+	crt := &dummyCert{}
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
-			peerCert: crt,
+			peerCert: &cert.CachedCertificate{Certificate: crt},
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
@@ -83,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         vpnIp2,
+		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
@@ -98,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 	}
 
-	thi := c.GetHostInfoByVpnIp(vpnIp, false)
+	thi := c.GetHostInfoByVpnAddr(vpnIp, false)
 
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  vpnIp,
+		VpnAddrs:               []netip.Addr{vpnIp},
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
@@ -113,14 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	assert.EqualValues(t, &expectedInfo, thi)
-	//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
-	//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
+	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(vpnIp2, false)
+		thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
 	})
 }
 

+ 43 - 22
control_tester.go

@@ -6,8 +6,6 @@ package nebula
 import (
 	"net/netip"
 
-	"github.com/slackhq/nebula/cert"
-
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
@@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
 	if toAddr.Addr().Is4() {
-		remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
 	} else {
-		remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
 	}
 }
 
@@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
+	remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
 }
 
 // GetFromTun will pull a packet off the tun side of nebula
@@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
-	//TODO: IPV6-WORK
-	ip := layers.IPv4{
-		Version:  4,
-		TTL:      64,
-		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().Addr().Unmap().AsSlice(),
-		DstIP:    toIp.Unmap().AsSlice(),
+func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
+	serialize := make([]gopacket.SerializableLayer, 0)
+	var netLayer gopacket.NetworkLayer
+	if toAddr.Is6() {
+		if !fromAddr.Is6() {
+			panic("Cant send ipv6 to ipv4")
+		}
+		ip := &layers.IPv6{
+			Version:    6,
+			NextHeader: layers.IPProtocolUDP,
+			SrcIP:      fromAddr.Unmap().AsSlice(),
+			DstIP:      toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
+	} else {
+		if !fromAddr.Is4() {
+			panic("Cant send ipv4 to ipv6")
+		}
+
+		ip := &layers.IPv4{
+			Version:  4,
+			TTL:      64,
+			Protocol: layers.IPProtocolUDP,
+			SrcIP:    fromAddr.Unmap().AsSlice(),
+			DstIP:    toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
 	}
 
 	udp := layers.UDP{
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 	}
-	err := udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(netLayer)
 	if err != nil {
 		panic(err)
 	}
@@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 		ComputeChecksums: true,
 		FixLengths:       true,
 	}
-	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+
+	serialize = append(serialize, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, serialize...)
 	if err != nil {
 		panic(err)
 	}
@@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 
-func (c *Control) GetVpnIp() netip.Addr {
-	return c.f.myVpnNet.Addr()
+func (c *Control) GetVpnAddrs() []netip.Addr {
+	return c.f.myVpnAddrs
 }
 
 func (c *Control) GetUDPAddr() netip.AddrPort {
@@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
 }
 
 func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
+	hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
 	if hostinfo == nil {
 		return false
 	}
@@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }
 
-func (c *Control) GetCert() *cert.NebulaCertificate {
-	return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertState() *CertState {
+	return c.f.pki.getCertState()
 }
 
 func (c *Control) ReHandshake(vpnIp netip.Addr) {

+ 79 - 39
dns_server.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/gaissmai/bart"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -21,24 +22,39 @@ var dnsAddr string
 
 type dnsRecords struct {
 	sync.RWMutex
-	dnsMap  map[string]string
-	hostMap *HostMap
+	l               *logrus.Logger
+	dnsMap4         map[string]netip.Addr
+	dnsMap6         map[string]netip.Addr
+	hostMap         *HostMap
+	myVpnAddrsTable *bart.Table[struct{}]
 }
 
-func newDnsRecords(hostMap *HostMap) *dnsRecords {
+func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
 	return &dnsRecords{
-		dnsMap:  make(map[string]string),
-		hostMap: hostMap,
+		l:               l,
+		dnsMap4:         make(map[string]netip.Addr),
+		dnsMap6:         make(map[string]netip.Addr),
+		hostMap:         hostMap,
+		myVpnAddrsTable: cs.myVpnAddrsTable,
 	}
 }
 
-func (d *dnsRecords) Query(data string) string {
+func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
+	data = strings.ToLower(data)
 	d.RLock()
 	defer d.RUnlock()
-	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
-		return r
+	switch q {
+	case dns.TypeA:
+		if r, ok := d.dnsMap4[data]; ok {
+			return r
+		}
+	case dns.TypeAAAA:
+		if r, ok := d.dnsMap6[data]; ok {
+			return r
+		}
 	}
-	return ""
+
+	return netip.Addr{}
 }
 
 func (d *dnsRecords) QueryCert(data string) string {
@@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 	}
 
-	hostinfo := d.hostMap.QueryVpnIp(ip)
+	hostinfo := d.hostMap.QueryVpnAddr(ip)
 	if hostinfo == nil {
 		return ""
 	}
@@ -57,43 +73,69 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 	}
 
-	cert := q.Details
-	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
-	return c
+	b, err := q.Certificate.MarshalJSON()
+	if err != nil {
+		return ""
+	}
+	return string(b)
 }
 
-func (d *dnsRecords) Add(host, data string) {
+// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
+func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
+	host = strings.ToLower(host)
 	d.Lock()
 	defer d.Unlock()
-	d.dnsMap[strings.ToLower(host)] = data
+	haveV4 := false
+	haveV6 := false
+	for _, addr := range addresses {
+		if addr.Is4() && !haveV4 {
+			d.dnsMap4[host] = addr
+			haveV4 = true
+		} else if addr.Is6() && !haveV6 {
+			d.dnsMap6[host] = addr
+			haveV6 = true
+		}
+		if haveV4 && haveV6 {
+			break
+		}
+	}
 }
 
-func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
+func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
+	a, _, _ := net.SplitHostPort(addr)
+	b, err := netip.ParseAddr(a)
+	if err != nil {
+		return false
+	}
+
+	if b.IsLoopback() {
+		return true
+	}
+
+	_, found := d.myVpnAddrsTable.Lookup(b)
+	return found //if we found it in this table, it's good
+}
+
+func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 	for _, q := range m.Question {
 		switch q.Qtype {
-		case dns.TypeA:
-			l.Debugf("Query for A %s", q.Name)
-			ip := dnsR.Query(q.Name)
-			if ip != "" {
-				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
+		case dns.TypeA, dns.TypeAAAA:
+			qType := dns.TypeToString[q.Qtype]
+			d.l.Debugf("Query for %s %s", qType, q.Name)
+			ip := d.Query(q.Qtype, q.Name)
+			if ip.IsValid() {
+				rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
 				if err == nil {
 					m.Answer = append(m.Answer, rr)
 				}
 			}
 		case dns.TypeTXT:
-			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b, err := netip.ParseAddr(a)
-			if err != nil {
+			// We only answer these queries from nebula nodes or localhost
+			if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
 				return
 			}
-
-			// We don't answer these queries from non nebula nodes or localhost
-			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
-			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
-				return
-			}
-			l.Debugf("Query for TXT %s", q.Name)
-			ip := dnsR.QueryCert(q.Name)
+			d.l.Debugf("Query for TXT %s", q.Name)
+			ip := d.QueryCert(q.Name)
 			if ip != "" {
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				if err == nil {
@@ -108,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 	}
 }
 
-func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
+func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.Compress = false
 
 	switch r.Opcode {
 	case dns.OpcodeQuery:
-		parseQuery(l, m, w)
+		d.parseQuery(m, w)
 	}
 
 	w.WriteMsg(m)
 }
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
-	dnsR = newDnsRecords(hostMap)
+func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
+	dnsR = newDnsRecords(l, cs, hostMap)
 
 	// attach request handler func
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		handleDnsRequest(l, w, r)
-	})
+	dns.HandleFunc(".", dnsR.handleDnsRequest)
 
 	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)

+ 20 - 5
dns_server_test.go

@@ -1,23 +1,38 @@
 package nebula
 
 import (
+	"net/netip"
 	"testing"
 
 	"github.com/miekg/dns"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestParsequery(t *testing.T) {
-	//TODO: This test is basically pointless
+	l := logrus.New()
 	hostMap := &HostMap{}
-	ds := newDnsRecords(hostMap)
-	ds.Add("test.com.com", "1.2.3.4")
+	ds := newDnsRecords(l, &CertState{}, hostMap)
+	addrs := []netip.Addr{
+		netip.MustParseAddr("1.2.3.4"),
+		netip.MustParseAddr("1.2.3.5"),
+		netip.MustParseAddr("fd01::24"),
+		netip.MustParseAddr("fd01::25"),
+	}
+	ds.Add("test.com.com", addrs)
 
-	m := new(dns.Msg)
+	m := &dns.Msg{}
 	m.SetQuestion("test.com.com", dns.TypeA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
 
-	//parseQuery(m)
+	m = &dns.Msg{}
+	m.SetQuestion("test.com.com", dns.TypeAAAA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
 }
 
 func Test_getDnsServerAddr(t *testing.T) {

File diff suppressed because it is too large
+ 389 - 168
e2e/handshakes_test.go


+ 0 - 125
e2e/helpers.go

@@ -1,125 +0,0 @@
-package e2e
-
-import (
-	"crypto/rand"
-	"io"
-	"net"
-	"net/netip"
-	"time"
-
-	"github.com/slackhq/nebula/cert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-)
-
-// NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "test ca",
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           true,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	if len(ips) > 0 {
-		nc.Details.Ips = make([]*net.IPNet, len(ips))
-		for i, ip := range ips {
-			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
-		}
-	}
-
-	if len(subnets) > 0 {
-		nc.Details.Subnets = make([]*net.IPNet, len(subnets))
-		for i, ip := range subnets {
-			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
-		}
-	}
-
-	if len(groups) > 0 {
-		nc.Details.Groups = groups
-	}
-
-	err = nc.Sign(cert.Curve_CURVE25519, priv)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, priv, pem
-}
-
-// NewTestCert will generate a signed certificate with the provided details.
-// Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
-	issuer, err := ca.Sha256Sum()
-	if err != nil {
-		panic(err)
-	}
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	pub, rawPriv := x25519Keypair()
-	ipb := ip.Addr().AsSlice()
-	nc := &cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name: name,
-			Ips:  []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
-			//Subnets:        subnets,
-			Groups:         groups,
-			NotBefore:      time.Unix(before.Unix(), 0),
-			NotAfter:       time.Unix(after.Unix(), 0),
-			PublicKey:      pub,
-			IsCA:           false,
-			Issuer:         issuer,
-			InvertedGroups: make(map[string]struct{}),
-		},
-	}
-
-	err = nc.Sign(ca.Details.Curve, key)
-	if err != nil {
-		panic(err)
-	}
-
-	pem, err := nc.MarshalToPEM()
-	if err != nil {
-		panic(err)
-	}
-
-	return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}

+ 75 - 38
e2e/helpers_test.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/netip"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
@@ -17,6 +18,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/stretchr/testify/assert"
@@ -26,27 +28,37 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
+func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 
-	vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
-	if err != nil {
-		panic(err)
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
 	}
 
 	var udpAddr netip.AddrPort
-	if vpnIpNet.Addr().Is4() {
-		budpIp := vpnIpNet.Addr().As4()
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
 		budpIp[1] -= 128
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
 	} else {
-		budpIp := vpnIpNet.Addr().As16()
-		budpIp[13] -= 128
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
-	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
-	caB, err := caCrt.MarshalToPEM()
+	caB, err := caCrt.MarshalPEM()
 	if err != nil {
 		panic(err)
 	}
@@ -88,11 +100,16 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
 	}
 
 	if overrides != nil {
-		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		final := m{}
+		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
 		if err != nil {
 			panic(err)
 		}
-		mc = overrides
+		mc = final
 	}
 
 	cb, err := yaml.Marshal(mc)
@@ -109,7 +126,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
 		panic(err)
 	}
 
-	return control, vpnIpNet, udpAddr, c
+	return control, vpnNetworks, udpAddr, c
 }
 
 type doneCb func()
@@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
-	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
+	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 	// And once more from me to them
-	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
+	controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 
-func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
+	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
+	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
-	hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
+	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
+	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
+	assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
 
 	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
 	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
 	// Check that our indexes match
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
-
-	//TODO: Would be nice to assert this memory
-	//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
-	//	hBbyIndex := hmA.Indexes[hBinA.localIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
-	//	assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
-	//
-	//	//TODO: remote indexes are susceptible to collision
-	//	hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
-	//	assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
-	//}
-	//
-	//// Check hostmap indexes too
-	//checkIndexes("hmA", hmA, hBinA)
-	//checkIndexes("hmB", hmB, hAinB)
 }
 
 func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	if toIp.Is6() {
+		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
+	} else {
+		assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
+	}
+}
+
+func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
+	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+	assert.NotNil(t, v6, "No ipv6 data found")
+
+	assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
+
+	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	assert.NotNil(t, udp, "No udp data found")
+
+	assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
+	assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
+
+	data := packet.ApplicationLayer()
+	assert.NotNil(t, data)
+	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
+}
+
+func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
@@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 
+func getAddrs(ns []netip.Prefix) []netip.Addr {
+	var a []netip.Addr
+	for _, n := range ns {
+		a = append(a, n.Addr())
+	}
+	return a
+}
+
 func NewTestLogger() *logrus.Logger {
 	l := logrus.New()
 

+ 5 - 4
e2e/router/hostmap.go

@@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	var lines []string
 	var globalLines []*edge
 
-	clusterName := strings.Trim(c.GetCert().Details.Name, " ")
-	clusterVpnIp := c.GetCert().Details.Ips[0].IP
+	crt := c.GetCertState().GetDefaultCertificate()
+	clusterName := strings.Trim(crt.Name(), " ")
+	clusterVpnIp := crt.Networks()[0].Addr()
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 
 	hm := c.GetHostmap()
@@ -101,8 +102,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	for _, idx := range indexes {
 		hi, ok := hm.Indexes[idx]
 		if ok {
-			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
-			remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
+			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
+			remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			_ = hi
 		}

+ 44 - 22
e2e/router/router.go

@@ -10,8 +10,8 @@ import (
 	"os"
 	"path/filepath"
 	"reflect"
+	"regexp"
 	"sort"
-	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 			panic("Duplicate listen address: " + addr.String())
 		}
 
-		r.vpnControls[c.GetVpnIp()] = c
+		for _, vpnAddr := range c.GetVpnAddrs() {
+			r.vpnControls[vpnAddr] = c
+		}
+
 		r.controls[addr] = c
 	}
 
@@ -213,11 +216,11 @@ func (r *R) renderFlow() {
 			continue
 		}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr.String(), ":", "-", 1)
+		sanAddr := normalizeName(addr.String())
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
-			sanAddr, e.packet.from.GetVpnIp(), sanAddr,
+			sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
 		)
 	}
 
@@ -250,9 +253,9 @@ func (r *R) renderFlow() {
 
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.from.GetUDPAddr().String()),
 				line,
-				strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.to.GetUDPAddr().String()),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 		}
@@ -267,6 +270,11 @@ func (r *R) renderFlow() {
 	}
 }
 
+func normalizeName(s string) string {
+	rx := regexp.MustCompile("[\\[\\]\\:]")
+	return rx.ReplaceAllLiteralString(s, "_")
+}
+
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
+		return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
 	})
 
 	s := renderHostmaps(c...)
@@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 		// Nope, lets push the sender along
 		case p := <-udpTx:
 			r.Lock()
-			c := r.getControl(sender.GetUDPAddr(), p.To, p)
+			a := sender.GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 				r.Unlock()
-				panic("No control for udp tx")
+				panic("No control for udp tx " + a.String())
 			}
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			c.InjectUDPPacket(p)
@@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
-			c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
+			a := cm[x].GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 				r.Unlock()
-				panic("No control for udp tx")
+				panic(fmt.Sprintf("No control for udp tx %s", p.To))
 			}
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			c.InjectUDPPacket(p)
@@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
 }
 
 func (r *R) formatUdpPacket(p *packet) string {
-	packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
-	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
-	if v4 == nil {
-		panic("not an ipv4 packet")
+	var packet gopacket.Packet
+	var srcAddr netip.Addr
+
+	packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
+	if packet.ErrorLayer() == nil {
+		v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
+	} else {
+		packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
+		v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
 	}
 
 	from := "unknown"
-	srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
 	if c, ok := r.vpnControls[srcAddr]; ok {
 		from = c.GetUDPAddr().String()
 	}
 
-	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
-	if udp == nil {
+	udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	if udpLayer == nil {
 		panic("not a udp packet")
 	}
 
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
-		udp.SrcPort,
-		udp.DstPort,
+		normalizeName(from),
+		normalizeName(p.to.GetUDPAddr().String()),
+		udpLayer.SrcPort,
+		udpLayer.DstPort,
 		string(data.Payload()),
 	)
 }

+ 12 - 5
examples/config.yml

@@ -13,6 +13,12 @@ pki:
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   #disconnect_invalid: true
 
+  # default_version controls which certificate version is used in handshakes.
+  # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
+  # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
+  # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
+  # default_version: 1
+
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # The syntax is:
@@ -285,7 +291,6 @@ tun:
     # send multiport handshakes.
     #tx_handshake_delay: 2
 
-# TODO
 # Configure logging level
 logging:
   # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
@@ -377,10 +382,12 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
-  #      Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
-  #      if `default_local_cidr_any` is false, otherwise its `any`.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
+  #     If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
+  #     Otherwise the default is any vpn network assigned to via the certificate.
+  #     `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
+  #     If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
 

+ 96 - 94
firewall.go

@@ -22,7 +22,7 @@ import (
 )
 
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
 }
 
 type conn struct {
@@ -51,10 +51,13 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 
-	// Used to ensure we don't emit local packets for ips we don't own
-	localIps     *bart.Table[struct{}]
-	assignedCIDR netip.Prefix
-	hasSubnets   bool
+	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
+	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
+	routableNetworks *bart.Table[struct{}]
+
+	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
+	assignedNetworks  []netip.Prefix
+	hasUnsafeNetworks bool
 
 	rules        string
 	rulesVersion uint16
@@ -67,9 +70,9 @@ type Firewall struct {
 }
 
 type firewallMetrics struct {
-	droppedLocalIP  metrics.Counter
-	droppedRemoteIP metrics.Counter
-	droppedNoRule   metrics.Counter
+	droppedLocalAddr  metrics.Counter
+	droppedRemoteAddr metrics.Counter
+	droppedNoRule     metrics.Counter
 }
 
 type FirewallConntrack struct {
@@ -126,88 +129,87 @@ type firewallLocalCIDR struct {
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
-func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
+// The certificate provided should be the highest version loaded in memory.
+func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
 	//TODO: error on 0 duration
-	var min, max time.Duration
+	var tmin, tmax time.Duration
 
 	if tcpTimeout < UDPTimeout {
-		min = tcpTimeout
-		max = UDPTimeout
+		tmin = tcpTimeout
+		tmax = UDPTimeout
 	} else {
-		min = UDPTimeout
-		max = tcpTimeout
+		tmin = UDPTimeout
+		tmax = tcpTimeout
 	}
 
-	if defaultTimeout < min {
-		min = defaultTimeout
-	} else if defaultTimeout > max {
-		max = defaultTimeout
+	if defaultTimeout < tmin {
+		tmin = defaultTimeout
+	} else if defaultTimeout > tmax {
+		tmax = defaultTimeout
 	}
 
-	localIps := new(bart.Table[struct{}])
-	var assignedCIDR netip.Prefix
-	var assignedSet bool
-	for _, ip := range c.Details.Ips {
-		//TODO: IPV6-WORK the unmap is a bit unfortunate
-		nip, _ := netip.AddrFromSlice(ip.IP)
-		nip = nip.Unmap()
-		nprefix := netip.PrefixFrom(nip, nip.BitLen())
-		localIps.Insert(nprefix, struct{}{})
-
-		if !assignedSet {
-			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = nprefix
-			assignedSet = true
-		}
+	routableNetworks := new(bart.Table[struct{}])
+	var assignedNetworks []netip.Prefix
+	for _, network := range c.Networks() {
+		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
+		routableNetworks.Insert(nprefix, struct{}{})
+		assignedNetworks = append(assignedNetworks, network)
 	}
 
-	for _, n := range c.Details.Subnets {
-		nip, _ := netip.AddrFromSlice(n.IP)
-		ones, _ := n.Mask.Size()
-		nip = nip.Unmap()
-		localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
+	hasUnsafeNetworks := false
+	for _, n := range c.UnsafeNetworks() {
+		routableNetworks.Insert(n, struct{}{})
+		hasUnsafeNetworks = true
 	}
 
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 			Conns:      make(map[firewall.Packet]*conn),
-			TimerWheel: NewTimerWheel[firewall.Packet](min, max),
+			TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
 		},
-		InRules:        newFirewallTable(),
-		OutRules:       newFirewallTable(),
-		TCPTimeout:     tcpTimeout,
-		UDPTimeout:     UDPTimeout,
-		DefaultTimeout: defaultTimeout,
-		localIps:       localIps,
-		assignedCIDR:   assignedCIDR,
-		hasSubnets:     len(c.Details.Subnets) > 0,
-		l:              l,
+		InRules:           newFirewallTable(),
+		OutRules:          newFirewallTable(),
+		TCPTimeout:        tcpTimeout,
+		UDPTimeout:        UDPTimeout,
+		DefaultTimeout:    defaultTimeout,
+		routableNetworks:  routableNetworks,
+		assignedNetworks:  assignedNetworks,
+		hasUnsafeNetworks: hasUnsafeNetworks,
+		l:                 l,
 
 		incomingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
 		},
 		outgoingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
 		},
 	}
 }
 
-func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
+	certificate := cs.getCertificate(cert.Version2)
+	if certificate == nil {
+		certificate = cs.getCertificate(cert.Version1)
+	}
+
+	if certificate == nil {
+		panic("No certificate available to reconfigure the firewall")
+	}
+
 	fw := NewFirewall(
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
-		nc,
+		certificate,
 		//TODO: max_connections
 	)
 
-	//TODO: Flip to false after v1.9 release
-	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
+	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
 
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	switch inboundAction {
@@ -287,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		fp = ft.TCP
 	case firewall.ProtoUDP:
 		fp = ft.UDP
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		fp = ft.ICMP
 	case firewall.ProtoAny:
 		fp = ft.AnyProto
@@ -421,33 +423,31 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) error {
+func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(fp, h, caPool, localCache) {
 		return nil
 	}
 
 	// Make sure remote address matches nebula certificate
-	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-		_, ok := remoteCidr.Lookup(fp.RemoteIP)
+	if h.networks != nil {
+		_, ok := h.networks.Lookup(fp.RemoteAddr)
 		if !ok {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.vpnIp {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-	_, ok := f.localIps.Lookup(fp.LocalIP)
+	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
 	if !ok {
-		f.metrics(incoming).droppedLocalIP.Inc(1)
+		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 	}
 
@@ -492,7 +492,7 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 
-func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.NebulaCAPool, localCache firewall.ConntrackCache) bool {
+func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
 	if localCache != nil {
 		if _, ok := localCache[fp]; ok {
 			return true
@@ -619,7 +619,7 @@ func (f *Firewall) evict(p firewall.Packet) {
 	delete(conntrack.Conns, p)
 }
 
-func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	if ft.AnyProto.match(p, incoming, c, caPool) {
 		return true
 	}
@@ -633,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 		}
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 		}
@@ -663,7 +663,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
 	return nil
 }
 
-func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	// We don't have any allowed ports, bail
 	if fp == nil {
 		return false
@@ -726,7 +726,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
 	return nil
 }
 
-func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
 	if fc == nil {
 		return false
 	}
@@ -735,18 +735,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 		return true
 	}
 
-	if t, ok := fc.CAShas[c.Details.Issuer]; ok {
+	if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok {
 		if t.match(p, c) {
 			return true
 		}
 	}
 
-	s, err := caPool.GetCAForCert(c)
+	s, err := caPool.GetCAForCert(c.Certificate)
 	if err != nil {
 		return false
 	}
 
-	return fc.CANames[s.Details.Name].match(p, c)
+	return fc.CANames[s.Certificate.Name()].match(p, c)
 }
 
 func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
@@ -826,7 +826,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
 	return false
 }
 
-func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool {
 	if fr == nil {
 		return false
 	}
@@ -841,7 +841,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		found := false
 
 		for _, g := range sg.Groups {
-			if _, ok := c.Details.InvertedGroups[g]; !ok {
+			if _, ok := c.InvertedGroups[g]; !ok {
 				found = false
 				break
 			}
@@ -855,42 +855,44 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 	}
 
 	if fr.Hosts != nil {
-		if flc, ok := fr.Hosts[c.Details.Name]; ok {
+		if flc, ok := fr.Hosts[c.Certificate.Name()]; ok {
 			if flc.match(p, c) {
 				return true
 			}
 		}
 	}
 
-	matched := false
-	prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
-	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
-		if prefix.Contains(p.RemoteIP) && val.match(p, c) {
-			matched = true
-			return false
+	for _, v := range fr.CIDR.Supernets(netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())) {
+		if v.match(p, c) {
+			return true
 		}
-		return true
-	})
-	return matched
+	}
+
+	return false
 }
 
 func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 	if !localIp.IsValid() {
-		if !f.hasSubnets || f.defaultLocalCIDRAny {
+		if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
 			flc.Any = true
 			return nil
 		}
 
-		localIp = f.assignedCIDR
+		for _, network := range f.assignedNetworks {
+			flc.LocalCIDR.Insert(network, struct{}{})
+		}
+		return nil
+
 	} else if localIp.Bits() == 0 {
 		flc.Any = true
+		return nil
 	}
 
 	flc.LocalCIDR.Insert(localIp, struct{}{})
 	return nil
 }
 
-func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool {
 	if flc == nil {
 		return false
 	}
@@ -899,7 +901,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
 		return true
 	}
 
-	_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
 	return ok
 }
 

+ 11 - 10
firewall/packet.go

@@ -10,18 +10,19 @@ import (
 type m map[string]interface{}
 
 const (
-	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	ProtoTCP  = 6
-	ProtoUDP  = 17
-	ProtoICMP = 1
+	ProtoAny    = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP    = 6
+	ProtoUDP    = 17
+	ProtoICMP   = 1
+	ProtoICMPv6 = 58
 
 	PortAny      = 0  // Special value for matching `port: any`
 	PortFragment = -1 // Special value for matching `port: fragment`
 )
 
 type Packet struct {
-	LocalIP    netip.Addr
-	RemoteIP   netip.Addr
+	LocalAddr  netip.Addr
+	RemoteAddr netip.Addr
 	LocalPort  uint16
 	RemotePort uint16
 	Protocol   uint8
@@ -30,8 +31,8 @@ type Packet struct {
 
 func (fp *Packet) Copy() *Packet {
 	return &Packet{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
+		LocalAddr:  fp.LocalAddr,
+		RemoteAddr: fp.RemoteAddr,
 		LocalPort:  fp.LocalPort,
 		RemotePort: fp.RemotePort,
 		Protocol:   fp.Protocol,
@@ -52,8 +53,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 	}
 	return json.Marshal(m{
-		"LocalIP":    fp.LocalIP.String(),
-		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalAddr":  fp.LocalAddr.String(),
+		"RemoteAddr": fp.RemoteAddr.String(),
 		"LocalPort":  fp.LocalPort,
 		"RemotePort": fp.RemotePort,
 		"Protocol":   proto,

+ 136 - 195
firewall_test.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"errors"
 	"math"
-	"net"
 	"net/netip"
 	"testing"
 	"time"
@@ -14,11 +13,12 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewFirewall(t *testing.T) {
 	l := test.NewLogger()
-	c := &cert.NebulaCertificate{}
+	c := &dummyCert{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	conntrack := fw.Conntrack
 	assert.NotNil(t, conntrack)
@@ -60,7 +60,7 @@ func TestFirewall_AddRule(t *testing.T) {
 	ob := &bytes.Buffer{}
 	l.SetOutput(ob)
 
-	c := &cert.NebulaCertificate{}
+	c := &dummyCert{}
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
@@ -129,35 +129,30 @@ func TestFirewall_Drop(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 	}
 
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			Groups:         []string{"default-group"},
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-			Issuer:         "signer-shasum",
-		},
+	c := dummyCert{
+		name:     "host1",
+		networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
 	}
 	h := HostInfo{
 		ConnectionState: &ConnectionState{
-			peerCert: &c,
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
 		},
-		vpnIp: netip.MustParseAddr("1.2.3.4"),
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.networks, c.unsafeNetworks)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -172,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	// test remote mismatch
-	oldRemote := p.RemoteIP
-	p.RemoteIP = netip.MustParseAddr("1.2.3.10")
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
-	p.RemoteIP = oldRemote
+	p.RemoteAddr = oldRemote
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@@ -190,14 +185,14 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
-	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
-	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
@@ -217,7 +212,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 
 	b.Run("fail on proto", func(b *testing.B) {
 		// This benchmark is showing us the cost of failing to match the protocol
-		c := &cert.NebulaCertificate{}
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
 		for n := 0; n < b.N; n++ {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
 		}
@@ -225,28 +222,31 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 
 	b.Run("pass proto, fail on port", func(b *testing.B) {
 		// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
-		c := &cert.NebulaCertificate{}
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
 		for n := 0; n < b.N; n++ {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
 		}
 	})
 
 	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
-		c := &cert.NebulaCertificate{}
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
 		ip := netip.MustParsePrefix("9.254.254.254/32")
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 	})
 
 	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
-		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "nope",
-				Ips:            []*net.IPNet{ip},
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
 			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		for n := 0; n < b.N; n++ {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
@@ -254,25 +254,24 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	})
 
 	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
-		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "nope",
-				Ips:            []*net.IPNet{ip},
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
 			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
 
 	b.Run("pass on group on any local cidr", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"good-group": {}},
-				Name:           "nope",
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
 			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
 		}
 		for n := 0; n < b.N; n++ {
 			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
@@ -280,82 +279,28 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	})
 
 	b.Run("pass on group on specific local cidr", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"good-group": {}},
-				Name:           "nope",
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
 			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
 		}
 		for n := 0; n < b.N; n++ {
-			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
 
 	b.Run("pass on name", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "good-host",
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "good-host",
 			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		for n := 0; n < b.N; n++ {
 			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 	})
-	//
-	//b.Run("pass on ip", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
-	//	}
-	//})
-	//
-	//b.Run("pass on local ip", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
-	//	}
-	//})
-	//
-	//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
-	//
-	//b.Run("pass on ip with any port", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
-	//	}
-	//})
-	//
-	//b.Run("pass on local ip with any port", func(b *testing.B) {
-	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
-	//	c := &cert.NebulaCertificate{
-	//		Details: cert.NebulaCertificateDetails{
-	//			InvertedGroups: map[string]struct{}{"nope": {}},
-	//			Name:           "good-host",
-	//		},
-	//	}
-	//	for n := 0; n < b.N; n++ {
-	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
-	//	}
-	//})
 }
 
 func TestFirewall_Drop2(t *testing.T) {
@@ -364,49 +309,47 @@ func TestFirewall_Drop2(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 	}
 
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
+	network := netip.MustParsePrefix("1.2.3.4/24")
 
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
 		},
+		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
 	}
 	h := HostInfo{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
-	c1 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
+	c1 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
 		},
+		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
 	}
 	h1 := HostInfo{
+		vpnAddrs: []netip.Addr{network.Addr()},
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
 	}
-	h1.CreateRemoteCIDR(&c1)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
-	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
@@ -423,72 +366,68 @@ func TestFirewall_Drop3(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  1,
 		RemotePort: 1,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 	}
 
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name: "host-owner",
-			Ips:  []*net.IPNet{&ipNet},
+	network := netip.MustParsePrefix("1.2.3.4/24")
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{network},
 		},
 	}
 
-	c1 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:   "host1",
-			Ips:    []*net.IPNet{&ipNet},
-			Issuer: "signer-sha-bad",
+	c1 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
+			issuer:   "signer-sha-bad",
 		},
 	}
 	h1 := HostInfo{
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h1.CreateRemoteCIDR(&c1)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
-	c2 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:   "host2",
-			Ips:    []*net.IPNet{&ipNet},
-			Issuer: "signer-sha",
+	c2 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host2",
+			networks: []netip.Prefix{network},
+			issuer:   "signer-sha",
 		},
 	}
 	h2 := HostInfo{
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h2.CreateRemoteCIDR(&c2)
+	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
 
-	c3 := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:   "host3",
-			Ips:    []*net.IPNet{&ipNet},
-			Issuer: "signer-sha-bad",
+	c3 := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host3",
+			networks: []netip.Prefix{network},
+			issuer:   "signer-sha-bad",
 		},
 	}
 	h3 := HostInfo{
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h3.CreateRemoteCIDR(&c3)
+	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
 
-	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
 	cp := cert.NewCAPool()
@@ -501,6 +440,11 @@ func TestFirewall_Drop3(t *testing.T) {
 	// c3 should fail because no match
 	resetConntrack(fw)
 	assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
+
+	// Test a remote address match
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
+	assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 }
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -509,37 +453,33 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Fragment:   false,
 	}
-
-	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
-		Mask: net.IPMask{255, 255, 255, 0},
-	}
-
-	c := cert.NebulaCertificate{
-		Details: cert.NebulaCertificateDetails{
-			Name:           "host1",
-			Ips:            []*net.IPNet{&ipNet},
-			Groups:         []string{"default-group"},
-			InvertedGroups: map[string]struct{}{"default-group": {}},
-			Issuer:         "signer-shasum",
+	network := netip.MustParsePrefix("1.2.3.4/24")
+
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host1",
+			networks: []netip.Prefix{network},
+			groups:   []string{"default-group"},
+			issuer:   "signer-shasum",
 		},
+		InvertedGroups: map[string]struct{}{"default-group": {}},
 	}
 	h := HostInfo{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
-	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
@@ -552,7 +492,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	oldFw := fw
-	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
@@ -561,7 +501,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	oldFw = fw
-	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
@@ -641,8 +581,6 @@ func BenchmarkLookup(b *testing.B) {
 			ml(m, a)
 		}
 	})
-
-	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
 }
 
 func Test_parsePort(t *testing.T) {
@@ -688,56 +626,59 @@ func Test_parsePort(t *testing.T) {
 func TestNewFirewallFromConfig(t *testing.T) {
 	l := test.NewLogger()
 	// Test a bad rule definition
-	c := &cert.NebulaCertificate{}
+	c := &dummyCert{}
+	cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
+	require.NoError(t, err)
+
 	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
-	_, err := NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 	// Test both port and code
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 	// Test proto error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 	// Test cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 

+ 21 - 19
go.mod

@@ -1,52 +1,54 @@
 module github.com/slackhq/nebula
 
-go 1.22.0
+go 1.23.6
 
-toolchain go1.22.2
+toolchain go1.23.7
 
 require (
-	dario.cat/mergo v1.0.0
+	dario.cat/mergo v1.0.1
 	github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
 	github.com/armon/go-radix v1.0.0
 	github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432
 	github.com/flynn/noise v1.1.0
-	github.com/gaissmai/bart v0.11.1
+	github.com/gaissmai/bart v0.18.1
 	github.com/gogo/protobuf v1.3.2
 	github.com/google/gopacket v1.1.19
 	github.com/kardianos/service v1.2.2
-	github.com/miekg/dns v1.1.61
+	github.com/miekg/dns v1.1.62
+	github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b
 	github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f
-	github.com/prometheus/client_golang v1.19.1
+	github.com/prometheus/client_golang v1.20.4
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
-	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
+	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stretchr/testify v1.9.0
-	github.com/vishvananda/netlink v1.2.1-beta.2
-	golang.org/x/crypto v0.26.0
+	github.com/vishvananda/netlink v1.3.0
+	golang.org/x/crypto v0.36.0
 	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
-	golang.org/x/net v0.28.0
-	golang.org/x/sync v0.8.0
-	golang.org/x/sys v0.24.0
-	golang.org/x/term v0.23.0
+	golang.org/x/net v0.37.0
+	golang.org/x/sync v0.12.0
+	golang.org/x/sys v0.31.0
+	golang.org/x/term v0.30.0
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b
 	golang.zx2c4.com/wireguard/windows v0.5.3
-	google.golang.org/protobuf v1.34.2
+	google.golang.org/protobuf v1.36.5
 	gopkg.in/yaml.v2 v2.4.0
 	gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe
 )
 
 require (
 	github.com/beorn7/perks v1.0.1 // indirect
-	github.com/bits-and-blooms/bitset v1.13.0 // indirect
-	github.com/cespare/xxhash/v2 v2.2.0 // indirect
+	github.com/cespare/xxhash/v2 v2.3.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
+	github.com/klauspost/compress v1.17.9 // indirect
+	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	github.com/prometheus/client_model v0.5.0 // indirect
-	github.com/prometheus/common v0.48.0 // indirect
-	github.com/prometheus/procfs v0.12.0 // indirect
+	github.com/prometheus/client_model v0.6.1 // indirect
+	github.com/prometheus/common v0.55.0 // indirect
+	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect
 	golang.org/x/mod v0.18.0 // indirect
 	golang.org/x/time v0.5.0 // indirect

+ 42 - 37
go.sum

@@ -1,6 +1,6 @@
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
-dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
+dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
+dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@@ -14,11 +14,9 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
-github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
-github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 h1:M5QgkYacWj0Xs8MhpIK/5uwU02icXpEoSo9sM2aRCps=
 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432/go.mod h1:xwIwAxMvYnVrGJPe2FKx5prTrnAjGOD8zvDOnxnrrkM=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -26,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
-github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
-github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
+github.com/gaissmai/bart v0.18.1 h1:bX2j560JC1MJpoEDevBGvXL5OZ1mkls320Vl8Igb5QQ=
+github.com/gaissmai/bart v0.18.1/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@@ -70,6 +68,8 @@ github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX
 github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
+github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@@ -80,13 +80,19 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
+github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
 github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
-github.com/miekg/dns v1.1.61 h1:nLxbwF3XxhwVSm8g9Dghm9MHPaUZuqhPiGL+675ZmEs=
-github.com/miekg/dns v1.1.61/go.mod h1:mnAarhS3nWaW+NVP2wTkYVIZyHNJ098SJZUki3eykwQ=
+github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
+github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk=
+github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
+github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
 github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f h1:8dM0ilqKL0Uzl42GABzzC4Oqlc3kGRILz0vgoff7nwg=
@@ -100,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP
 github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
 github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
 github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0=
-github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
-github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
+github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI=
+github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
 github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
-github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
+github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
+github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
 github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
 github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo=
 github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc=
-github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
-github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
+github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
+github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
 github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
 github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
-github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
-github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
+github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
+github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM=
 github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -129,8 +135,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
+github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
@@ -139,9 +145,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
-github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs=
-github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
-github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
+github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk=
+github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs=
 github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -151,8 +156,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
-golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
-golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
 golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@@ -171,8 +176,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
-golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
+golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
+golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -180,30 +185,30 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
-golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg=
-golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
-golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
+golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -234,8 +239,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
-google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
+google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
+google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

+ 223 - 104
handshake_ix.go

@@ -2,10 +2,12 @@ package nebula
 
 import (
 	"net/netip"
+	"slices"
 	"time"
 
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 )
@@ -17,23 +19,59 @@ import (
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 	}
 
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
+	// If we're connecting to a v6 address we must use a v2 cert
+	cs := f.pki.getCertState()
+	v := cs.defaultVersion
+	for _, a := range hh.hostinfo.vpnAddrs {
+		if a.Is6() {
+			v = cert.Version2
+			break
+		}
+	}
+
+	crt := cs.getCertificate(v)
+	if crt == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate is available")
+		return false
+	}
+
+	crtHs := cs.getHandshakeBytes(v)
+	if crtHs == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Failed to create connection state")
+		return false
+	}
 	hh.hostinfo.ConnectionState = ci
 
-	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hh.hostinfo.localIndexId,
-		Time:           uint64(time.Now().UnixNano()),
-		Cert:           certState.RawCertificateNoKey,
+	hs := &NebulaHandshake{
+		Details: &NebulaHandshakeDetails{
+			InitiatorIndex: hh.hostinfo.localIndexId,
+			Time:           uint64(time.Now().UnixNano()),
+			Cert:           crtHs,
+			CertVersion:    uint32(v),
+		},
 	}
 
 	if f.multiPort.Tx || f.multiPort.Rx {
-		hsProto.InitiatorMultiPort = &MultiPortDetails{
+		hs.Details.InitiatorMultiPort = &MultiPortDetails{
 			RxSupported: f.multiPort.Rx,
 			TxSupported: f.multiPort.Tx,
 			BasePort:    uint32(f.multiPort.TxBasePort),
@@ -41,15 +79,9 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 		}
 	}
 
-	hsBytes := []byte{}
-
-	hs := &NebulaHandshake{
-		Details: hsProto,
-	}
-	hsBytes, err = hs.Marshal()
-
+	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 	}
@@ -58,7 +90,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 	}
@@ -73,30 +105,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 }
 
 func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
+	cs := f.pki.getCertState()
+	crt := cs.GetDefaultCertificate()
+	if crt == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", cs.defaultVersion).
+			Error("Unable to handshake with host because no certificate is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to create connection state")
+		return
+	}
+
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to call noise.ReadMessage")
 		return
 	}
 
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
-	/*
-		l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
-	*/
 	if err != nil || hs.Details == nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed unmarshal handshake message")
 		return
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@@ -109,8 +155,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 	}
 
-	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
-	if !ok {
+	if remoteCert.Certificate.Version() != ci.myCert.Version() {
+		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
+		rc := cs.getCertificate(remoteCert.Certificate.Version())
+		if rc == nil {
+			f.l.WithError(err).WithField("udpAddr", addr).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
+				Info("Unable to handshake with host due to missing certificate version")
+			return
+		}
+
+		// Record the certificate we are actually using
+		ci.myCert = rc
+	}
+
+	if len(remoteCert.Certificate.Networks()) == 0 {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
 
@@ -122,30 +181,54 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 	}
 
-	vpnIp = vpnIp.Unmap()
-	certName := remoteCert.Details.Name
-	fingerprint, _ := remoteCert.Sha256Sum()
-	issuer := remoteCert.Details.Issuer
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	certName := remoteCert.Certificate.Name()
+	fingerprint := remoteCert.Fingerprint
+	issuer := remoteCert.Certificate.Issuer()
+
+	for _, network := range remoteCert.Certificate.Networks() {
+		vpnAddr := network.Addr()
+		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
+		if found {
+			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("fingerprint", fingerprint).
+				WithField("issuer", issuer).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			return
+		}
+
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
 
-	if vpnIp == f.myVpnNet.Addr() {
-		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
 		return
 	}
 
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		// addr can be invalid when the tunnel is being relayed.
+		// We only want to apply the remote allow list for direct tunnels here
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
 
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -177,19 +260,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		vpnIp:             vpnIp,
+		vpnAddrs:          vpnAddrs,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		multiportTx:       multiportTx,
 		multiportRx:       multiportRx,
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}
 
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
@@ -199,13 +282,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		Info("Handshake message received")
 
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = certState.RawCertificateNoKey
+	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
+	if hs.Details.Cert == nil {
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("certVersion", ci.myCert.Version()).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return
+	}
+
+	hs.Details.CertVersion = uint32(ci.myCert.Version())
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
 	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -216,14 +312,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -247,9 +343,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
-	hostinfo.CreateRemoteCIDR(remoteCert)
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -263,7 +359,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
-				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 
 			msg = existing.HandshakePacket[2]
@@ -278,11 +374,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 					err = f.outside.WriteTo(msg, addr)
 				}
 				if err != nil {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 				} else {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 				}
@@ -292,16 +388,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 				}
-				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
+				f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 				return
 			}
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@@ -312,23 +408,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				Info("Handshake too old")
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -351,7 +447,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			err = f.outside.WriteTo(msg, addr)
 		}
 		if err != nil {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -359,7 +455,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake")
 		} else {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -372,9 +468,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 		}
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+		// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
+		// it's correctly marked as working.
+		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -401,8 +500,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	hostinfo := hh.hostinfo
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
 	}
@@ -410,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -419,7 +519,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 
@@ -431,7 +531,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -452,9 +552,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		)
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
-		e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 
 		if f.l.Level > logrus.DebugLevel {
@@ -467,8 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		return true
 	}
 
-	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
-	if !ok {
+	if len(remoteCert.Certificate.Networks()) == 0 {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 
@@ -476,18 +575,55 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			e = e.WithField("cert", remoteCert)
 		}
 
-		e.Info("Invalid vpn ip from host")
+		e.Info("Empty networks from host")
 		return true
 	}
 
-	vpnIp = vpnIp.Unmap()
-	certName := remoteCert.Details.Name
-	fingerprint, _ := remoteCert.Sha256Sum()
-	issuer := remoteCert.Details.Issuer
+	vpnNetworks := remoteCert.Certificate.Networks()
+	certName := remoteCert.Certificate.Name()
+	fingerprint := remoteCert.Fingerprint
+	issuer := remoteCert.Certificate.Issuer()
+
+	hostinfo.remoteIndexId = hs.Details.ResponderIndex
+	hostinfo.lastHandshakeTime = hs.Details.Time
+
+	// Store their cert and our symmetric keys
+	ci.peerCert = remoteCert
+	ci.dKey = NewNebulaCipherState(dKey)
+	ci.eKey = NewNebulaCipherState(eKey)
+
+	// Make sure the current udpAddr being used is set for responding
+	if addr.IsValid() {
+		hostinfo.SetRemote(addr)
+	} else {
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+	}
+
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	for _, network := range vpnNetworks {
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		vpnAddr := network.Addr()
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+		return true
+	}
 
 	// Ensure the right host responded
-	if vpnIp != hostinfo.vpnIp {
-		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
+	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
@@ -496,16 +632,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
-			//TODO: this doesnt know if its being added or is being used for caching a packet
+		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes.BlockRemote(addr)
 
-			// Get the correct remote list for the host we did handshake with
-			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
-
-			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
+				WithField("vpnNetworks", vpnNetworks).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				Info("Blocked addresses for handshakes")
 
@@ -513,8 +646,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			newHH.packetStore = hh.packetStore
 			hh.packetStore = []*cachedPacket{}
 
-			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-			hostinfo.vpnIp = vpnIp
+			// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnAddrs = vpnAddrs
 			f.sendCloseTunnel(hostinfo)
 		})
 
@@ -525,7 +658,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
@@ -536,25 +669,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx).
 		Info("Handshake message received")
 
-	hostinfo.remoteIndexId = hs.Details.ResponderIndex
-	hostinfo.lastHandshakeTime = hs.Details.Time
-
-	// Store their cert and our symmetric keys
-	ci.peerCert = remoteCert
-	ci.dKey = NewNebulaCipherState(dKey)
-	ci.eKey = NewNebulaCipherState(eKey)
-
-	// Make sure the current udpAddr being used is set for responding
-	if addr.IsValid() {
-		hostinfo.SetRemote(addr)
-	} else {
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
-	}
-
 	// Build up the radix for the firewall if we have subnets in the cert
-	hostinfo.CreateRemoteCIDR(remoteCert)
+	hostinfo.vpnAddrs = vpnAddrs
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
-	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
+	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 

+ 181 - 124
handshake_manager.go

@@ -7,14 +7,15 @@ import (
 	"encoding/binary"
 	"errors"
 	"net/netip"
+	"slices"
 	"sync"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
-	"golang.org/x/exp/slices"
 )
 
 const (
@@ -121,18 +122,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context) {
-	clockSource := time.NewTicker(c.config.tryInterval)
+func (hm *HandshakeManager) Run(ctx context.Context) {
+	clockSource := time.NewTicker(hm.config.tryInterval)
 	defer clockSource.Stop()
 
 	for {
 		select {
 		case <-ctx.Done():
 			return
-		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, true)
+		case vpnIP := <-hm.trigger:
+			hm.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now)
+			hm.NextOutboundHandshakeTimerTick(now)
 		}
 	}
 }
@@ -140,7 +141,7 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
 	if addr.IsValid() {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
@@ -162,14 +163,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
-	c.OutboundHandshakeTimer.Advance(now)
+func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
+	hm.OutboundHandshakeTimer.Advance(now)
 	for {
-		vpnIp, has := c.OutboundHandshakeTimer.Purge()
+		vpnIp, has := hm.OutboundHandshakeTimer.Purge()
 		if !has {
 			break
 		}
-		c.handleOutbound(vpnIp, false)
+		hm.handleOutbound(vpnIp, false)
 	}
 }
 
@@ -211,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
 	}
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@@ -226,7 +227,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 
 	hh.lastRemotes = remotes
 
-	// TODO: this will generate a load of queries for hosts with only 1 ip
+	// This will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
 	if len(remotes) <= 1 && hh.counter == 5 {
@@ -293,59 +294,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
+			// Don't relay to myself
+			if relay == vpnIp {
 				continue
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+
+			// Don't relay through the host I'm trying to connect to
+			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
+			if found {
+				continue
+			}
+
+			relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
 			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hm.f.Handshake(relay)
 				continue
 			}
-			// Check the relay HostInfo to see if we already established a relay through it
-			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
-				switch existingRelay.State {
-				case Established:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
-				case Requested:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
-
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
-					// Re-send the CreateRelay request, in case the previous one was lost.
-					m := NebulaControl{
-						Type:                NebulaControl_CreateRelayRequest,
-						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
-					}
-					msg, err := m.Marshal()
-					if err != nil {
-						hostinfo.logger(hm.l).
-							WithError(err).
-							Error("Failed to marshal Control message to create relay")
-					} else {
-						// This must send over the hostinfo, not over hm.Hosts[ip]
-						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
-							"relayTo":             vpnIp,
-							"initiatorRelayIndex": existingRelay.LocalIndex,
-							"relay":               relay}).
-							Info("send CreateRelayRequest")
-					}
-				default:
-					hostinfo.logger(hm.l).
-						WithField("vpnIp", vpnIp).
-						WithField("state", existingRelay.State).
-						WithField("relay", relayHostInfo.vpnIp).
-						Errorf("Relay unexpected state")
-				}
-			} else {
+			// Check the relay HostInfo to see if we already established a relay through
+			existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
+			if !ok {
 				// No relays exist or requested yet.
 				if relayHostInfo.remote.IsValid() {
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
@@ -353,16 +321,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
+
+					switch relayHostInfo.GetCert().Certificate.Version() {
+					case cert.Version1:
+						if !hm.f.myVpnAddrs[0].Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+							continue
+						}
+
+						if !vpnIp.Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+							continue
+						}
+
+						b := hm.f.myVpnAddrs[0].As4()
+						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+						b = vpnIp.As4()
+						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					case cert.Version2:
+						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					default:
+						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						continue
+					}
+
 					msg, err := m.Marshal()
 					if err != nil {
 						hostinfo.logger(hm.l).
@@ -371,13 +358,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
+							"relayFrom":           hm.f.myVpnAddrs[0],
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
 							Info("send CreateRelayRequest")
 					}
 				}
+				continue
+			}
+
+			switch existingRelay.State {
+			case Established:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+				hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+			case Disestablished:
+				// Mark this relay as 'requested'
+				relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
+				fallthrough
+			case Requested:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+				// Re-send the CreateRelay request, in case the previous one was lost.
+				m := NebulaControl{
+					Type:                NebulaControl_CreateRelayRequest,
+					InitiatorRelayIndex: existingRelay.LocalIndex,
+				}
+
+				switch relayHostInfo.GetCert().Certificate.Version() {
+				case cert.Version1:
+					if !hm.f.myVpnAddrs[0].Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+						continue
+					}
+
+					if !vpnIp.Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+						continue
+					}
+
+					b := hm.f.myVpnAddrs[0].As4()
+					m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+					b = vpnIp.As4()
+					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+				case cert.Version2:
+					m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+				default:
+					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+					continue
+				}
+				msg, err := m.Marshal()
+				if err != nil {
+					hostinfo.logger(hm.l).
+						WithError(err).
+						Error("Failed to marshal Control message to create relay")
+				} else {
+					// This must send over the hostinfo, not over hm.Hosts[ip]
+					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+					hm.l.WithFields(logrus.Fields{
+						"relayFrom":           hm.f.myVpnAddrs[0],
+						"relayTo":             vpnIp,
+						"initiatorRelayIndex": existingRelay.LocalIndex,
+						"relay":               relay}).
+						Info("send CreateRelayRequest")
+				}
+			case PeerRequested:
+				// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
+				fallthrough
+			default:
+				hostinfo.logger(hm.l).
+					WithField("vpnIp", vpnIp).
+					WithField("state", existingRelay.State).
+					WithField("relay", relay).
+					Errorf("Relay unexpected state")
+
 			}
 		}
 	}
@@ -407,10 +461,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 }
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 
-	if hh, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnAddr]; ok {
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
 			cacheCb(hh)
@@ -420,12 +474,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 	}
 
 	hostinfo := &HostInfo{
-		vpnIp:           vpnIp,
+		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}
 
@@ -433,9 +487,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 		hostinfo:  hostinfo,
 		startTime: time.Now(),
 	}
-	hm.vpnIps[vpnIp] = hh
+	hm.vpnIps[vpnAddr] = hh
 	hm.metricInitiated.Inc(1)
-	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
+	hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
 
 	if cacheCb != nil {
 		cacheCb(hh)
@@ -443,21 +497,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// We can trigger the handshake right now
-	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
 	if !doTrigger {
 		// Add any calculated remotes, and trigger early handshake if one found
-		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
 	}
 
 	if doTrigger {
 		select {
-		case hm.trigger <- vpnIp:
+		case hm.trigger <- vpnAddr:
 		default:
 		}
 	}
 
 	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp)
+	hm.lightHouse.QueryServer(vpnAddr)
 	return hostinfo
 }
 
@@ -478,14 +532,14 @@ var (
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
-	c.Lock()
-	defer c.Unlock()
+func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
 	if found && existingHostInfo != nil {
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
@@ -502,31 +556,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			return existingHostInfo, ErrExistingHostInfo
 		}
 
-		existingHostInfo.logger(c.l).Info("Taking new handshake")
+		existingHostInfo.logger(hm.l).Info("Taking new handshake")
 	}
 
-	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
 	if found {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
 	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
 		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 	}
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+		hostinfo.logger(hm.l).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 	}
 
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 }
 
@@ -544,7 +598,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(hm.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 	}
 
@@ -581,31 +635,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 }
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	c.Lock()
-	defer c.Unlock()
-	c.unlockedDeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
 }
 
-func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	delete(c.vpnIps, hostinfo.vpnIp)
-	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
+func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	for _, addr := range hostinfo.vpnAddrs {
+		delete(hm.vpnIps, addr)
 	}
 
-	delete(c.indexes, hostinfo.localIndexId)
-	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HandshakeHostInfo{}
+	if len(hm.vpnIps) == 0 {
+		hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+	delete(hm.indexes, hostinfo.localIndexId)
+	if len(hm.indexes) == 0 {
+		hm.indexes = map[uint32]*HandshakeHostInfo{}
+	}
+
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Pending hostmap hostInfo deleted")
 	}
 }
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
+func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 		return hh.hostinfo
@@ -634,37 +691,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 }
 
-func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
-	return c.mainHostMap.GetPreferredRanges()
+func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
+	return hm.mainHostMap.GetPreferredRanges()
 }
 
-func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
-	for _, v := range c.vpnIps {
+	for _, v := range hm.vpnIps {
 		f(v.hostinfo)
 	}
 }
 
-func (c *HandshakeManager) ForEachIndex(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
-	for _, v := range c.indexes {
+	for _, v := range hm.indexes {
 		f(v.hostinfo)
 	}
 }
 
-func (c *HandshakeManager) EmitStats() {
-	c.RLock()
-	hostLen := len(c.vpnIps)
-	indexLen := len(c.indexes)
-	c.RUnlock()
+func (hm *HandshakeManager) EmitStats() {
+	hm.RLock()
+	hostLen := len(hm.vpnIps)
+	indexLen := len(hm.indexes)
+	hm.RUnlock()
 
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
-	c.mainHostMap.EmitStats()
+	hm.mainHostMap.EmitStats()
 }
 
 // Utility functions below

+ 18 - 11
handshake_manager_test.go

@@ -14,21 +14,20 @@ import (
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	ip := netip.MustParseAddr("172.1.1.2")
 
 	preferredRanges := []netip.Prefix{localrange}
-	mainHM := newHostMap(l, vpncidr)
+	mainHM := newHostMap(l)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
 	lh := newTestLighthouse()
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &cert.NebulaCertificate{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -42,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 
-	i.remotes = NewRemoteList(nil)
+	i.remotes = NewRemoteList([]netip.Addr{}, nil)
 
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
@@ -80,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 type mockEncWriter struct {
 }
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
 	return
 }
 
-func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
 	return
 }
 
-func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
 	return
 }
 
-func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
+func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
+
+func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
+	return nil
+}
+
+func (mw *mockEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: cert.Version2}
+}

+ 171 - 104
hostmap.go

@@ -35,6 +35,7 @@ const (
 	Requested = iota
 	PeerRequested
 	Established
+	Disestablished
 )
 
 const (
@@ -48,7 +49,7 @@ type Relay struct {
 	State       int
 	LocalIndex  uint32
 	RemoteIndex uint32
-	PeerIp      netip.Addr
+	PeerAddr    netip.Addr
 }
 
 type HostMap struct {
@@ -58,7 +59,6 @@ type HostMap struct {
 	RemoteIndexes   map[uint32]*HostInfo
 	Hosts           map[netip.Addr]*HostInfo
 	preferredRanges atomic.Pointer[[]netip.Prefix]
-	vpnCIDR         netip.Prefix
 	l               *logrus.Logger
 }
 
@@ -68,9 +68,12 @@ type HostMap struct {
 type RelayState struct {
 	sync.RWMutex
 
-	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
+	relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
+	// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
+	// the RelayState Lock held)
+	relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx  map[uint32]*Relay     // Maps a local index to some Relay info
 }
 
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
@@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	delete(rs.relays, ip)
 }
 
+func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByAddr[vpnIp]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
+func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByIdx[idx]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
 func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	rs.RLock()
 	defer rs.RUnlock()
@@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 }
 
-func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[ip]
+	r, ok := rs.relayForByAddr[addr]
 	return r, ok
 }
 
@@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
 func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	defer rs.RUnlock()
-	currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
-	for relayIp := range rs.relayForByIp {
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
+	for relayIp := range rs.relayForByAddr {
 		currentRelays = append(currentRelays, relayIp)
 	}
 	return currentRelays
@@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	defer rs.Unlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	if !ok {
 		return false
 	}
@@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return true
 }
 
@@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return &newRelay, true
 }
 
 func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	return r, ok
 }
 
@@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	defer rs.Unlock()
-	rs.relayForByIp[ip] = r
+	rs.relayForByAddr[ip] = r
 	rs.relayForByIdx[idx] = r
 }
 
@@ -190,10 +215,16 @@ type HostInfo struct {
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	localIndexId    uint32
-	vpnIp           netip.Addr
-	recvError       atomic.Uint32
-	remoteCidr      *bart.Table[struct{}]
-	relayState      RelayState
+
+	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
+	// The host may have other vpn addresses that are outside our
+	// vpn networks but were removed because they are not usable
+	vpnAddrs  []netip.Addr
+	recvError atomic.Uint32
+
+	// networks are both all vpn and unsafe networks assigned to this host
+	networks   *bart.Table[struct{}]
+	relayState RelayState
 
 	// If true, we should send to this remote using multiport
 	multiportTx bool
@@ -247,28 +278,26 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
-	hm := newHostMap(l, vpnCIDR)
+func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
+	hm := newHostMap(l)
 
 	hm.reload(c, true)
 	c.RegisterReloadCallback(func(c *config.C) {
 		hm.reload(c, false)
 	})
 
-	l.WithField("network", hm.vpnCIDR.String()).
-		WithField("preferredRanges", hm.GetPreferredRanges()).
+	l.WithField("preferredRanges", hm.GetPreferredRanges()).
 		Info("Main HostMap created")
 
 	return hm
 }
 
-func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
+func newHostMap(l *logrus.Logger) *HostMap {
 	return &HostMap{
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		Hosts:         map[netip.Addr]*HostInfo{},
-		vpnCIDR:       vpnCIDR,
 		l:             l,
 	}
 }
@@ -311,17 +340,6 @@ func (hm *HostMap) EmitStats() {
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 
-func (hm *HostMap) RemoveRelay(localIdx uint32) {
-	hm.Lock()
-	_, ok := hm.Relays[localIdx]
-	if !ok {
-		hm.Unlock()
-		return
-	}
-	delete(hm.Relays, localIdx)
-	hm.Unlock()
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
@@ -341,48 +359,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 }
 
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
-	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	// Get the current primary, if it exists
+	oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
+
+	// Every address in the hostinfo gets elevated to primary
+	for _, vpnAddr := range hostinfo.vpnAddrs {
+		//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
+		// indexes so it should be fine.
+		hm.Hosts[vpnAddr] = hostinfo
+	}
+
+	// If we are already primary then we won't bother re-linking
 	if oldHostinfo == hostinfo {
 		return
 	}
 
+	// Unlink this hostinfo
 	if hostinfo.prev != nil {
 		hostinfo.prev.next = hostinfo.next
 	}
-
 	if hostinfo.next != nil {
 		hostinfo.next.prev = hostinfo.prev
 	}
 
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
+	// If there wasn't a previous primary then clear out any links
 	if oldHostinfo == nil {
+		hostinfo.next = nil
+		hostinfo.prev = nil
 		return
 	}
 
+	// Relink the hostinfo as primary
 	hostinfo.next = oldHostinfo
 	oldHostinfo.prev = hostinfo
 	hostinfo.prev = nil
 }
 
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	for _, addr := range hostinfo.vpnAddrs {
+		h := hm.Hosts[addr]
+		for h != nil {
+			if h == hostinfo {
+				hm.unlockedInnerDeleteHostInfo(h, addr)
+			}
+			h = h.next
+		}
+	}
+}
+
+func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
+	primary, ok := hm.Hosts[addr]
+	isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
 	if ok && primary == hostinfo {
-		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
-		delete(hm.Hosts, hostinfo.vpnIp)
+		// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, addr)
 		if len(hm.Hosts) == 0 {
 			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 
 		if hostinfo.next != nil {
-			// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
-			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
+			hm.Hosts[addr] = hostinfo.next
 			// It is primary, there is no previous hostinfo now
 			hostinfo.next.prev = nil
 		}
 
 	} else {
-		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		// Relink if we were in the middle of multiple hostinfos for this vpn addr
 		if hostinfo.prev != nil {
 			hostinfo.prev.next = hostinfo.next
 		}
@@ -412,10 +455,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
 
+	if isLastHostinfo {
+		// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
+		// hops as 'Requested' so that new relay tunnels are created in the future.
+		hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
+	}
+	// Clean up any local relay indexes for which I am acting as a relay hop
 	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
 		delete(hm.Relays, localRelayIdx)
 	}
@@ -454,11 +503,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 }
 
-func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
-	return hm.queryVpnIp(vpnIp, nil)
+func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
+	return hm.queryVpnAddr(vpnIp, nil)
 }
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -466,17 +515,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
 	if !ok {
 		return nil, nil, errors.New("unable to find host")
 	}
+
 	for h != nil {
-		r, ok := h.relayState.QueryRelayForByIp(targetIp)
-		if ok && r.State == Established {
-			return h, r, nil
+		for _, targetIp := range targetIps {
+			r, ok := h.relayState.QueryRelayForByIp(targetIp)
+			if ok && r.State == Established {
+				return h, r, nil
+			}
 		}
 		h = h.next
 	}
+
 	return nil, nil, errors.New("unable to find host with relay")
 }
 
-func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
+	for _, relayHostIp := range hi.relayState.CopyRelayIps() {
+		if h, ok := hm.Hosts[relayHostIp]; ok {
+			for h != nil {
+				h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+				h = h.next
+			}
+		}
+	}
+	for _, rs := range hi.relayState.CopyAllRelayFor() {
+		if rs.Type == ForwardingType {
+			if h, ok := hm.Hosts[rs.PeerAddr]; ok {
+				for h != nil {
+					h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+					h = h.next
+				}
+			}
+		}
+	}
+}
+
+func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -497,25 +571,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
-		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
+		dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
 	}
-
-	existing := hm.Hosts[hostinfo.vpnIp]
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
-	if existing != nil {
-		hostinfo.next = existing
-		existing.prev = hostinfo
+	for _, addr := range hostinfo.vpnAddrs {
+		hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
 	}
 
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
+		hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
 			Debug("Hostmap vpnIp added")
 	}
+}
+
+func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
+	existing := hm.Hosts[vpnAddr]
+	hm.Hosts[vpnAddr] = hostinfo
+
+	if existing != nil && existing != hostinfo {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
 
 	i := 1
 	check := hostinfo
@@ -533,7 +612,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	return *hm.preferredRanges.Load()
 }
 
-func (hm *HostMap) ForEachVpnIp(f controlEach) {
+func (hm *HostMap) ForEachVpnAddr(f controlEach) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -587,11 +666,11 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
 		}
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp)
+		ifce.lightHouse.QueryServer(i.vpnAddrs[0])
 	}
 }
 
-func (i *HostInfo) GetCert() *cert.NebulaCertificate {
+func (i *HostInfo) GetCert() *cert.CachedCertificate {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
 	}
@@ -602,7 +681,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	if i.remote != remote {
 		i.remote = remote
-		i.remotes.LearnRemote(i.vpnIp, remote)
+		i.remotes.LearnRemote(i.vpnAddrs[0], remote)
 	}
 }
 
@@ -653,29 +732,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 }
 
-func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
-	if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
+func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+	if len(networks) == 1 && len(unsafeNetworks) == 0 {
 		// Simple case, no CIDRTree needed
 		return
 	}
 
-	remoteCidr := new(bart.Table[struct{}])
-	for _, ip := range c.Details.Ips {
-		//TODO: IPV6-WORK what to do when ip is invalid?
-		nip, _ := netip.AddrFromSlice(ip.IP)
-		nip = nip.Unmap()
-		bits, _ := ip.Mask.Size()
-		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
+	i.networks = new(bart.Table[struct{}])
+	for _, network := range networks {
+		i.networks.Insert(network, struct{}{})
 	}
 
-	for _, n := range c.Details.Subnets {
-		//TODO: IPV6-WORK what to do when ip is invalid?
-		nip, _ := netip.AddrFromSlice(n.IP)
-		nip = nip.Unmap()
-		bits, _ := n.Mask.Size()
-		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
+	for _, network := range unsafeNetworks {
+		i.networks.Insert(network, struct{}{})
 	}
-	i.remoteCidr = remoteCidr
 }
 
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@@ -683,13 +753,13 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 	}
 
-	li := l.WithField("vpnIp", i.vpnIp).
+	li := l.WithField("vpnAddrs", i.vpnAddrs).
 		WithField("localIndex", i.localIndexId).
 		WithField("remoteIndex", i.remoteIndexId)
 
 	if connState := i.ConnectionState; connState != nil {
 		if peerCert := connState.peerCert; peerCert != nil {
-			li = li.WithField("certName", peerCert.Details.Name)
+			li = li.WithField("certName", peerCert.Certificate.Name())
 		}
 	}
 
@@ -698,9 +768,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 // Utility functions
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
+func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
-	var ips []netip.Addr
+	var finalAddrs []netip.Addr
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
@@ -712,39 +782,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 			continue
 		}
 		addrs, _ := i.Addrs()
-		for _, addr := range addrs {
-			var ip net.IP
-			switch v := addr.(type) {
+		for _, rawAddr := range addrs {
+			var addr netip.Addr
+			switch v := rawAddr.(type) {
 			case *net.IPNet:
 				//continue
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			case *net.IPAddr:
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			}
 
-			nip, ok := netip.AddrFromSlice(ip)
-			if !ok {
+			if !addr.IsValid() {
 				if l.Level >= logrus.DebugLevel {
-					l.WithField("localIp", ip).Debug("ip was invalid for netip")
+					l.WithField("localAddr", rawAddr).Debug("addr was invalid")
 				}
 				continue
 			}
-			nip = nip.Unmap()
+			addr = addr.Unmap()
 
-			//TODO: Filtering out link local for now, this is probably the most correct thing
-			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
-				allow := allowList.Allow(nip)
+			if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
+				isAllowed := allowList.Allow(addr)
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
 				}
-				if !allow {
+				if !isAllowed {
 					continue
 				}
 
-				ips = append(ips, nip)
+				finalAddrs = append(finalAddrs, addr)
 			}
 		}
 	}
-	return ips
+	return finalAddrs
 }

+ 23 - 33
hostmap_test.go

@@ -11,17 +11,14 @@ import (
 
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
 
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
-	h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
-	h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
+	h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
+	h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
@@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 }
 
@@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 
-	hm := NewHostMapFromConfig(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-		c,
-	)
+	hm := NewHostMapFromConfig(l, c)
 
 	toS := func(ipn []netip.Prefix) []string {
 		var s []string

+ 2 - 2
hostmap_tester.go

@@ -9,8 +9,8 @@ import (
 	"net/netip"
 )
 
-func (i *HostInfo) GetVpnIp() netip.Addr {
-	return i.vpnIp
+func (i *HostInfo) GetVpnAddrs() []netip.Addr {
+	return i.vpnAddrs
 }
 
 func (i *HostInfo) GetLocalIndex() uint32 {

+ 31 - 27
inside.go

@@ -21,14 +21,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
-		return
+	if f.dropLocalBroadcast {
+		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
+		if found {
+			return
+		}
 	}
 
-	if fwPacket.RemoteIP == f.myVpnNet.Addr() {
+	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
+	if found {
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
-		// routes packets from the Nebula IP to the Nebula IP through the Nebula
+		// routes packets from the Nebula addr to the Nebula addr through the Nebula
 		// TUN device.
 		if immediatelyForwardToSelf {
 			_, err := f.readers[q].Write(packet)
@@ -37,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 			}
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
-		// routes packets from the nebula IP to the nebula IP through the loopback device.
+		// routes packets from the nebula addr to the nebula addr through the loopback device.
 		return
 	}
 
 	// Ignore multicast packets
-	if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
+	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", fwPacket.RemoteIP).
+			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
-				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
+				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
 		return
 	}
@@ -118,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil)
 }
 
-func (f *Interface) Handshake(vpnIp netip.Addr) {
-	f.getOrHandshake(vpnIp, nil)
+func (f *Interface) Handshake(vpnAddr netip.Addr) {
+	f.getOrHandshake(vpnAddr, nil)
 }
 
-// getOrHandshake returns nil if the vpnIp is not routable.
+// getOrHandshake returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !f.myVpnNet.Contains(vpnIp) {
-		vpnIp = f.inside.RouteFor(vpnIp)
-		if !vpnIp.IsValid() {
+func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
+	if !found {
+		vpnAddr = f.inside.RouteFor(vpnAddr)
+		if !vpnAddr.IsValid() {
 			return nil, false
 		}
 	}
 
-	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
+	return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -157,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil)
 }
 
-// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
+	hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", vpnIp).
-				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
+			f.l.WithField("vpnAddr", vpnAddr).
+				Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
 		}
 		return
 	}
@@ -259,7 +264,6 @@ func (f *Interface) SendVia(via *HostInfo,
 
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int, udpPortGetter udp.SendPortGetter) {
 	if ci.eKey == nil {
-		//TODO: log warning
 		return
 	}
 
@@ -303,14 +307,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	f.connectionManager.Out(hostinfo.localIndexId)
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
-	// all our IPs and enable a faster roaming.
+	// all our addrs and enable a faster roaming.
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp)
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 	}
 
@@ -354,7 +358,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
 			if err != nil {
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

+ 84 - 77
interface.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"context"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -12,6 +11,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -28,7 +28,6 @@ type InterfaceConfig struct {
 	Outside                 udp.Conn
 	Inside                  overlay.Device
 	pki                     *PKI
-	Cipher                  string
 	Firewall                *Firewall
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
@@ -52,25 +51,27 @@ type InterfaceConfig struct {
 }
 
 type Interface struct {
-	hostMap            *HostMap
-	outside            udp.Conn
-	inside             overlay.Device
-	pki                *PKI
-	cipher             string
-	firewall           *Firewall
-	connectionManager  *connectionManager
-	handshakeManager   *HandshakeManager
-	serveDns           bool
-	createTime         time.Time
-	lightHouse         *LightHouse
-	myBroadcastAddr    netip.Addr
-	myVpnNet           netip.Prefix
-	dropLocalBroadcast bool
-	dropMulticast      bool
-	routines           int
-	disconnectInvalid  atomic.Bool
-	closed             atomic.Bool
-	relayManager       *relayManager
+	hostMap               *HostMap
+	outside               udp.Conn
+	inside                overlay.Device
+	pki                   *PKI
+	firewall              *Firewall
+	connectionManager     *connectionManager
+	handshakeManager      *HandshakeManager
+	serveDns              bool
+	createTime            time.Time
+	lightHouse            *LightHouse
+	myBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
+	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	dropLocalBroadcast    bool
+	dropMulticast         bool
+	routines              int
+	disconnectInvalid     atomic.Bool
+	closed                atomic.Bool
+	relayManager          *relayManager
 
 	tryPromoteEvery atomic.Uint32
 	reQueryEvery    atomic.Uint32
@@ -114,9 +115,11 @@ type EncWriter interface {
 		out []byte,
 		nocopy bool,
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
+	SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp netip.Addr)
+	Handshake(vpnAddr netip.Addr)
+	GetHostInfo(vpnAddr netip.Addr) *HostInfo
+	GetCertState() *CertState
 }
 
 type sendRecvErrorConfig uint8
@@ -127,10 +130,10 @@ const (
 	sendRecvErrorPrivate
 )
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
 	switch s {
 	case sendRecvErrorPrivate:
-		return ip.Addr().IsPrivate()
+		return endpoint.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 		return true
 	case sendRecvErrorNever:
@@ -167,47 +170,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 	}
 
-	certificate := c.pki.GetCertState().Certificate
-
-	myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
-	if !ok {
-		return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
-	}
-
-	myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
-	if !ok {
-		return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
-	}
-
-	myVpnAddr = myVpnAddr.Unmap()
-	myVpnMask = myVpnMask.Unmap()
-
-	if myVpnAddr.BitLen() != myVpnMask.BitLen() {
-		return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
-	}
-
-	ones, _ := certificate.Details.Ips[0].Mask.Size()
-	myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
-
+	cs := c.pki.getCertState()
 	ifce := &Interface{
-		pki:                c.pki,
-		hostMap:            c.HostMap,
-		outside:            c.Outside,
-		inside:             c.Inside,
-		cipher:             c.Cipher,
-		firewall:           c.Firewall,
-		serveDns:           c.ServeDns,
-		handshakeManager:   c.HandshakeManager,
-		createTime:         time.Now(),
-		lightHouse:         c.lightHouse,
-		dropLocalBroadcast: c.DropLocalBroadcast,
-		dropMulticast:      c.DropMulticast,
-		routines:           c.routines,
-		version:            c.version,
-		writers:            make([]udp.Conn, c.routines),
-		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnNet:           myVpnNet,
-		relayManager:       c.relayManager,
+		pki:                   c.pki,
+		hostMap:               c.HostMap,
+		outside:               c.Outside,
+		inside:                c.Inside,
+		firewall:              c.Firewall,
+		serveDns:              c.ServeDns,
+		handshakeManager:      c.HandshakeManager,
+		createTime:            time.Now(),
+		lightHouse:            c.lightHouse,
+		dropLocalBroadcast:    c.DropLocalBroadcast,
+		dropMulticast:         c.DropMulticast,
+		routines:              c.routines,
+		version:               c.version,
+		writers:               make([]udp.Conn, c.routines),
+		readers:               make([]io.ReadWriteCloser, c.routines),
+		myVpnNetworks:         cs.myVpnNetworks,
+		myVpnNetworksTable:    cs.myVpnNetworksTable,
+		myVpnAddrs:            cs.myVpnAddrs,
+		myVpnAddrsTable:       cs.myVpnAddrsTable,
+		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
+		relayManager:          c.relayManager,
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
@@ -221,12 +206,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 	}
 
-	if myVpnAddr.Is4() {
-		addr := myVpnNet.Masked().Addr().As4()
-		binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
-		ifce.myBroadcastAddr = netip.AddrFrom4(addr)
-	}
-
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -247,7 +226,7 @@ func (f *Interface) activate() {
 		f.l.WithError(err).Error("Failed to get udp listen address")
 	}
 
-	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
+	f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
@@ -290,16 +269,22 @@ func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 
 	var li udp.Conn
-	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 		li = f.writers[i]
 	} else {
 		li = f.outside
 	}
 
+	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
+	plaintext := make([]byte, udp.MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
+		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	})
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -356,7 +341,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
@@ -441,6 +426,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	var rawStats func()
 
 	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
+	certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
+	certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
 
 	for {
 		select {
@@ -450,17 +437,37 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
+
+			certState := f.pki.getCertState()
+			defaultCrt := certState.GetDefaultCertificate()
+			certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
+			certDefaultVersion.Update(int64(defaultCrt.Version()))
+
 			if f.udpRaw != nil {
 				if rawStats == nil {
 					rawStats = udp.NewRawStatsEmitter(f.udpRaw)
 				}
 				rawStats()
 			}
-			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
+
+			// Report the max certificate version we are capable of using
+			if certState.v2Cert != nil {
+				certMaxVersion.Update(int64(certState.v2Cert.Version()))
+			} else {
+				certMaxVersion.Update(int64(certState.v1Cert.Version()))
+			}
 		}
 	}
 }
 
+func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return f.hostMap.QueryVpnAddr(vpnIp)
+}
+
+func (f *Interface) GetCertState() *CertState {
+	return f.pki.getCertState()
+}
+
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 

+ 0 - 2
iputil/packet.go

@@ -6,8 +6,6 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
-//TODO: IPV6-WORK can probably delete this
-
 const (
 	// Need 96 bytes for the largest reject packet:
 	// - 20 byte ipv4 header

File diff suppressed because it is too large
+ 364 - 240
lighthouse.go


+ 173 - 146
lighthouse_test.go

@@ -7,6 +7,8 @@ import (
 	"net/netip"
 	"testing"
 
+	"github.com/gaissmai/bart"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
@@ -14,62 +16,51 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
-//TODO: Add a test to ensure udpAddr is copied and not reused
-
 func TestOldIPv4Only(t *testing.T) {
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
-	var m Ip4AndPort
+	var m V4AddrPort
 	err := m.Unmarshal(b)
 	assert.NoError(t, err)
 	ip := netip.MustParseAddr("10.1.1.1")
 	bp := ip.As4()
-	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
-}
-
-func TestNewLhQuery(t *testing.T) {
-	myIp, err := netip.ParseAddr("192.1.1.1")
-	assert.NoError(t, err)
-
-	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIp)
-
-	// The result should be a nebulameta protobuf
-	assert.IsType(t, &NebulaMeta{}, a)
-
-	// It should also Marshal fine
-	b, err := a.Marshal()
-	assert.Nil(t, err)
-
-	// and then Unmarshal fine
-	n := &NebulaMeta{}
-	err = n.Unmarshal(b)
-	assert.Nil(t, err)
-
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
 }
 
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
@@ -79,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	}
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
 
@@ -99,9 +90,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	if !assert.NoError(b, err) {
 		b.Fatal()
 	}
@@ -110,46 +107,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
 
 	vpnIp3 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp3] = NewRemoteList(nil)
+	lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
 	lh.addrMap[vpnIp3].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
-			NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
+			netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
 	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
 	vpnIp2 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp2] = NewRemoteList(nil)
+	lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
 	lh.addrMap[vpnIp2].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
-			NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
+			netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	mw := &mockEncWriter{}
 
+	hi := []netip.Addr{vpnIp2}
 	b.Run("notfound", func(b *testing.B) {
 		lhh := lh.NewRequestHandler()
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
-				VpnIp:       4,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  4,
+				V4AddrPorts: nil,
 			},
 		}
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 	})
 	b.Run("found", func(b *testing.B) {
@@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
-				VpnIp:       3,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  3,
+				V4AddrPorts: nil,
 			},
 		}
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 	})
 }
@@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	lh.ifce = &mockEncWriter{}
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
 
 	// Ensure we don't accumulate addresses
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
 
 	// Grow it back to 2
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	// Update a different host and ask about it
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Have both hosts ask about the other
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Make sure we didn't get changed
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	// Ensure proper ordering and limiting
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 		t,
-		r.msg.Details.Ip4AndPorts,
+		r.msg.Details.V4AddrPorts,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 
@@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	good := netip.MustParseAddrPort("1.128.0.99:4242")
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
 }
 
 func TestLighthouse_reload(t *testing.T) {
@@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 
 	nc := map[interface{}]interface{}{
@@ -290,13 +306,16 @@ func TestLighthouse_reload(t *testing.T) {
 }
 
 func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
-	//TODO: IPV6-WORK
-	bip := queryVpnIp.As4()
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostQuery,
-		Details: &NebulaMetaDetails{
-			VpnIp: binary.BigEndian.Uint32(bip[:]),
-		},
+		Type:    NebulaMeta_HostQuery,
+		Details: &NebulaMetaDetails{},
+	}
+
+	if queryVpnIp.Is4() {
+		bip := queryVpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
 	}
 
 	b, err := req.Marshal()
@@ -308,23 +327,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
 	w := &testEncWriter{
 		metaFilter: &filter,
 	}
-	lhh.HandleRequest(fromAddr, myVpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
 	return w.lastReply
 }
 
 func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
-	//TODO: IPV6-WORK
-	bip := vpnIp.As4()
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostUpdateNotification,
-		Details: &NebulaMetaDetails{
-			VpnIp:       binary.BigEndian.Uint32(bip[:]),
-			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
-		},
+		Type:    NebulaMeta_HostUpdateNotification,
+		Details: &NebulaMetaDetails{},
 	}
 
-	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
+	if vpnIp.Is4() {
+		bip := vpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
+	}
+
+	for _, v := range addrs {
+		if v.Addr().Is4() {
+			req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
+		} else {
+			req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
+		}
 	}
 
 	b, err := req.Marshal()
@@ -333,75 +358,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
 	}
 
 	w := &testEncWriter{}
-	lhh.HandleRequest(fromAddr, vpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
 }
 
-//TODO: this is a RemoteList test
-//func Test_lhRemoteAllowList(t *testing.T) {
-//	l := NewLogger()
-//	c := NewConfig(l)
-//	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
-//		"10.20.0.0/12": false,
-//	}
-//	allowList, err := c.GetAllowList("remoteallowlist", false)
-//	assert.Nil(t, err)
-//
-//	lh1 := "10.128.0.2"
-//	lh1IP := net.ParseIP(lh1)
-//
-//	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-//
-//	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-//	lh.SetRemoteAllowList(allowList)
-//
-//	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
-//	remote1IP := net.ParseIP("10.20.0.3")
-//	remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
-//	remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
-//	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
-//	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
-//
-//	// Make sure a good ip enters the cache and addrMap
-//	remote2IP := net.ParseIP("10.128.0.3")
-//	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
-//
-//	// Another good ip gets into the cache, ordering is inverted
-//	remote3IP := net.ParseIP("10.128.0.4")
-//	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
-//
-//	// If we exceed the length limit we should only have the most recent addresses
-//	addedAddrs := []*udpAddr{}
-//	for i := 0; i < 11; i++ {
-//		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
-//		lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
-//		// The first entry here is a duplicate, don't add it to the assert list
-//		if i != 0 {
-//			addedAddrs = append(addedAddrs, remoteUDPAddr)
-//		}
-//	}
-//
-//	// We should only have the last 10 of what we tried to add
-//	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
-//	assertUdpAddrInArray(
-//		t,
-//		lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
-//		addedAddrs[0],
-//		addedAddrs[1],
-//		addedAddrs[2],
-//		addedAddrs[3],
-//		addedAddrs[4],
-//		addedAddrs[5],
-//		addedAddrs[6],
-//		addedAddrs[7],
-//		addedAddrs[8],
-//		addedAddrs[9],
-//	)
-//}
-
 type testLhReply struct {
 	nebType    header.MessageType
 	nebSubType header.MessageSubType
@@ -410,8 +369,9 @@ type testLhReply struct {
 }
 
 type testEncWriter struct {
-	lastReply  testLhReply
-	metaFilter *NebulaMeta_MessageType
+	lastReply       testLhReply
+	metaFilter      *NebulaMeta_MessageType
+	protocolVersion cert.Version
 }
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@@ -426,7 +386,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 		tw.lastReply = testLhReply{
 			nebType:    t,
 			nebSubType: st,
-			vpnIp:      hostinfo.vpnIp,
+			vpnIp:      hostinfo.vpnAddrs[0],
 			msg:        msg,
 		}
 	}
@@ -436,7 +396,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 }
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -453,17 +413,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	}
 }
 
+func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return nil
+}
+
+func (tw *testEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: tw.protocolVersion}
+}
+
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
+func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 		return
 	}
 
 	for k, w := range want {
-		//TODO: IPV6-WORK
-		h := AddrPortFromIp4AndPort(have[k])
+		h := protoV4AddrPortToNetAddrPort(have[k])
 		if !(h == w) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 	}
 }
+
+func Test_findNetworkUnion(t *testing.T) {
+	var out netip.Addr
+	var ok bool
+
+	tenDot := netip.MustParsePrefix("10.0.0.0/8")
+	oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
+	fe80 := netip.MustParsePrefix("fe80::/8")
+	fc00 := netip.MustParsePrefix("fc00::/7")
+
+	a1 := netip.MustParseAddr("10.0.0.1")
+	afe81 := netip.MustParseAddr("fe80::1")
+
+	//simple
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed lengths
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed family
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//ordering
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//some mismatches
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//falsey cases
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+}

+ 5 - 34
main.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"context"
-	"encoding/binary"
 	"fmt"
 	"net"
 	"net/netip"
@@ -61,25 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
 
-	certificate := pki.GetCertState().Certificate
-	fw, err := NewFirewallFromConfig(l, certificate, c)
+	fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
-	ones, _ := certificate.Details.Ips[0].Mask.Size()
-	addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
-	if !ok {
-		err = util.NewContextualError(
-			"Invalid ip address in certificate",
-			m{"vpnIp": certificate.Details.Ips[0].IP},
-			nil,
-		)
-		return nil, err
-	}
-	tunCidr := netip.PrefixFrom(addr, ones)
-
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@@ -142,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			deviceFactory = overlay.NewDeviceFromConfig
 		}
 
-		tun, err = deviceFactory(c, l, tunCidr, routines)
+		tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
@@ -197,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	hostMap := NewHostMapFromConfig(l, tunCidr, c)
+	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
@@ -242,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		pki:                     pki,
-		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
@@ -264,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		l:                     l,
 	}
 
-	switch ifConfig.Cipher {
-	case "aes":
-		noiseEndianness = binary.BigEndian
-	case "chachapoly":
-		noiseEndianness = binary.LittleEndian
-	default:
-		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
-	}
-
 	var ifce *Interface
 	if !configTest {
 		ifce, err = NewInterface(ctx, ifConfig)
@@ -280,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			return nil, fmt.Errorf("failed to initialize interface: %s", err)
 		}
 
-		// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
-		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
 		lightHouse.ifce = ifce
 
@@ -326,8 +300,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		go handshakeManager.Run(ctx)
 	}
 
-	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
-	// a context so that they can exit when the context is Done.
 	statsStart, err := startStats(l, c, buildVersion, configTest)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
@@ -337,7 +309,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, nil
 	}
 
-	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
 	attachCommands(l, c, ssh, ifce)
@@ -346,7 +317,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	var dnsStart func()
 	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, c)
+		dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
 	}
 
 	return &Control{

+ 0 - 2
message_metrics.go

@@ -7,8 +7,6 @@ import (
 	"github.com/slackhq/nebula/header"
 )
 
-//TODO: this can probably move into the header package
-
 type MessageMetrics struct {
 	rx [][]metrics.Counter
 	tx [][]metrics.Counter

File diff suppressed because it is too large
+ 560 - 176
nebula.pb.go


+ 23 - 9
nebula.proto

@@ -23,19 +23,28 @@ message NebulaMeta {
 }
 
 message NebulaMetaDetails {
-  uint32 VpnIp = 1;
-  repeated Ip4AndPort Ip4AndPorts = 2;
-  repeated Ip6AndPort Ip6AndPorts = 4;
-  repeated uint32 RelayVpnIp = 5;
+  uint32 OldVpnAddr = 1 [deprecated = true];
+  Addr VpnAddr = 6;
+
+  repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
+  repeated Addr RelayVpnAddrs = 7;
+
+  repeated V4AddrPort V4AddrPorts = 2;
+  repeated V6AddrPort V6AddrPorts = 4;
   uint32 counter = 3;
 }
 
-message Ip4AndPort {
-  uint32 Ip = 1;
+message Addr {
+  uint64 Hi = 1;
+  uint64 Lo = 2;
+}
+
+message V4AddrPort {
+  uint32 Addr = 1;
   uint32 Port = 2;
 }
 
-message Ip6AndPort {
+message V6AddrPort {
   uint64 Hi = 1;
   uint64 Lo = 2;
   uint32 Port = 3;
@@ -69,6 +78,7 @@ message NebulaHandshakeDetails {
   uint32 ResponderIndex = 3;
   uint64 Cookie = 4;
   uint64 Time = 5;
+  uint32 CertVersion = 8;
 
   MultiPortDetails InitiatorMultiPort = 6;
   MultiPortDetails ResponderMultiPort = 7;
@@ -84,6 +94,10 @@ message NebulaControl {
 
   uint32 InitiatorRelayIndex = 2;
   uint32 ResponderRelayIndex = 3;
-  uint32 RelayToIp = 4;
-  uint32 RelayFromIp = 5;
+
+  uint32 OldRelayToAddr = 4 [deprecated = true];
+  uint32 OldRelayFromAddr = 5 [deprecated = true];
+
+  Addr RelayToAddr = 6;
+  Addr RelayFromAddr = 7;
 }

+ 50 - 0
noiseutil/pkcs11.go

@@ -0,0 +1,50 @@
+package noiseutil
+
+import (
+	"crypto/ecdh"
+	"fmt"
+	"strings"
+
+	"github.com/slackhq/nebula/pkclient"
+
+	"github.com/flynn/noise"
+)
+
+// DHP256PKCS11 is the NIST P-256 ECDH function
+var DHP256PKCS11 noise.DHFunc = newNISTP11Curve("P256", ecdh.P256(), 32)
+
+type nistP11Curve struct {
+	nistCurve
+}
+
+func newNISTP11Curve(name string, curve ecdh.Curve, byteLen int) nistP11Curve {
+	return nistP11Curve{
+		newNISTCurve(name, curve, byteLen),
+	}
+}
+
+func (c nistP11Curve) DH(privkey, pubkey []byte) ([]byte, error) {
+	//for this function "privkey" is actually a pkcs11 URI
+	pkStr := string(privkey)
+
+	//to set up a handshake, we need to also do non-pkcs11-DH. Handle that here.
+	if !strings.HasPrefix(pkStr, "pkcs11:") {
+		return DHP256.DH(privkey, pubkey)
+	}
+	ecdhPubKey, err := c.curve.NewPublicKey(pubkey)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal pubkey: %w", err)
+	}
+
+	//this is not the most performant way to do this (a long-lived client would be better)
+	//but, it works, and helps avoid problems with stale sessions and HSMs used by multiple users.
+	client, err := pkclient.FromUrl(pkStr)
+	if err != nil {
+		return nil, err
+	}
+	defer func(client *pkclient.PKClient) {
+		_ = client.Close()
+	}(client)
+
+	return client.DeriveNoise(ecdhPubKey.Bytes())
+}

+ 156 - 137
outside.go

@@ -3,46 +3,25 @@ package nebula
 import (
 	"encoding/binary"
 	"errors"
-	"fmt"
 	"net/netip"
 	"time"
 
-	"github.com/flynn/noise"
+	"github.com/google/gopacket/layers"
+	"golang.org/x/net/ipv6"
+
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
-	"google.golang.org/protobuf/proto"
 )
 
 const (
 	minFwPacketLen = 4
 )
 
-// TODO: IPV6-WORK this can likely be removed now
-func readOutsidePackets(f *Interface) udp.EncReader {
-	return func(
-		addr netip.AddrPort,
-		out []byte,
-		packet []byte,
-		header *header.H,
-		fwPacket *firewall.Packet,
-		lhh udp.LightHouseHandlerFunc,
-		nb []byte,
-		q int,
-		localCache firewall.ConntrackCache,
-	) {
-		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
-	}
-}
-
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
-		// TODO: best if we return this and let caller log
-		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
 			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
@@ -52,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	if ip.IsValid() {
-		if f.myVpnNet.Contains(ip.Addr()) {
+		_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
+		if found {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
@@ -109,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			if !ok {
 				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
 				// its internal mapping. This should never happen.
-				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
 				return
 			}
 
@@ -121,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
-				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
+				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
 				if err != nil {
-					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
+					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
 					return
 				}
 
@@ -139,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 					}
 				} else {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
+					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 					return
 				}
 			}
@@ -156,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
-
-			//TODO: maybe after build 64 is out? 06/14/2018 - NB
-			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
 			return
 		}
 
-		lhf(ip, hostinfo.vpnIp, d)
+		lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
 
 		// Fallthrough to the bottom to record incoming traffic
 
@@ -177,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
-
-			//TODO: maybe after build 64 is out? 06/14/2018 - NB
-			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
 			return
 		}
 
@@ -229,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				Error("Failed to decrypt Control packet")
 			return
 		}
-		m := &NebulaControl{}
-		err = m.Unmarshal(d)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
-			break
-		}
 
-		f.relayManager.HandleControlMsg(hostinfo, m, f)
+		f.relayManager.HandleControlMsg(hostinfo, d, f)
 
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -253,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	if final {
-		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
-		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
+		// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
+		f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
 	}
 }
 
@@ -263,35 +231,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
-	if ip.IsValid() && hostinfo.remote != ip {
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
+	if udpAddr.IsValid() && hostinfo.remote != udpAddr {
 		if hostinfo.multiportRx {
 			// If the remote is sending with multiport, we aren't roaming unless
 			// the IP has changed
-			if hostinfo.remote.Addr().Compare(ip.Addr()) == 0 {
+			if hostinfo.remote.Addr().Compare(udpAddr.Addr()) == 0 {
 				return
 			}
 			// Keep the port from the original hostinfo, because the remote is transmitting from multiport ports
-			ip = netip.AddrPortFrom(ip.Addr(), hostinfo.remote.Port())
+			udpAddr = netip.AddrPortFrom(udpAddr.Addr(), hostinfo.remote.Port())
 		}
 
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
-		if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+
+		if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(ip)
+		hostinfo.SetRemote(udpAddr)
 	}
 
 }
@@ -311,24 +280,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
 	return true
 }
 
+var (
+	ErrPacketTooShort          = errors.New("packet is too short")
+	ErrUnknownIPVersion        = errors.New("packet is an unknown ip version")
+	ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
+	ErrIPv4PacketTooShort      = errors.New("ipv4 packet is too short")
+	ErrIPv6PacketTooShort      = errors.New("ipv6 packet is too short")
+	ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
+)
+
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
-	// Do we at least have an ipv4 header worth of data?
-	if len(data) < ipv4.HeaderLen {
-		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
+	if len(data) < 1 {
+		return ErrPacketTooShort
+	}
+
+	version := int((data[0] >> 4) & 0x0f)
+	switch version {
+	case ipv4.Version:
+		return parseV4(data, incoming, fp)
+	case ipv6.Version:
+		return parseV6(data, incoming, fp)
+	}
+	return ErrUnknownIPVersion
+}
+
+func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
+	dataLen := len(data)
+	if dataLen < ipv6.HeaderLen {
+		return ErrIPv6PacketTooShort
 	}
 
-	// Is it an ipv4 packet?
-	if int((data[0]>>4)&0x0f) != 4 {
-		return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
+	if incoming {
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
+	} else {
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
+	}
+
+	protoAt := 6             // NextHeader is at 6 bytes into the ipv6 header
+	offset := ipv6.HeaderLen // Start at the end of the ipv6 header
+	next := 0
+	for {
+		if dataLen < offset {
+			break
+		}
+
+		proto := layers.IPProtocol(data[protoAt])
+		//fmt.Println(proto, protoAt)
+		switch proto {
+		case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
+			fp.Protocol = uint8(proto)
+			fp.RemotePort = 0
+			fp.LocalPort = 0
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolTCP, layers.IPProtocolUDP:
+			if dataLen < offset+4 {
+				return ErrIPv6PacketTooShort
+			}
+
+			fp.Protocol = uint8(proto)
+			if incoming {
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			} else {
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			}
+
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolIPv6Fragment:
+			// Fragment header is 8 bytes, need at least offset+4 to read the offset field
+			if dataLen < offset+8 {
+				return ErrIPv6PacketTooShort
+			}
+
+			// Check if this is the first fragment
+			fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
+			if fragmentOffset != 0 {
+				// Non-first fragment, use what we have now and stop processing
+				fp.Protocol = data[offset]
+				fp.Fragment = true
+				fp.RemotePort = 0
+				fp.LocalPort = 0
+				return nil
+			}
+
+			// The next loop should be the transport layer since we are the first fragment
+			next = 8 // Fragment headers are always 8 bytes
+
+		case layers.IPProtocolAH:
+			// Auth headers, used by IPSec, have a different meaning for header length
+			if dataLen < offset+1 {
+				break
+			}
+
+			next = int(data[offset+1]+2) << 2
+
+		default:
+			// Normal ipv6 header length processing
+			if dataLen < offset+1 {
+				break
+			}
+
+			next = int(data[offset+1]+1) << 3
+		}
+
+		if next <= 0 {
+			// Safety check, each ipv6 header has to be at least 8 bytes
+			next = 8
+		}
+
+		protoAt = offset
+		offset = offset + next
+	}
+
+	return ErrIPv6CouldNotFindPayload
+}
+
+func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
+	// Do we at least have an ipv4 header worth of data?
+	if len(data) < ipv4.HeaderLen {
+		return ErrIPv4PacketTooShort
 	}
 
 	// Adjust our start position based on the advertised ip header length
 	ihl := int(data[0]&0x0f) << 2
 
-	// Well formed ip header length?
+	// Well-formed ip header length?
 	if ihl < ipv4.HeaderLen {
-		return fmt.Errorf("packet had an invalid header length: %v", ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 
 	// Check if this is the second or further fragment of a fragmented packet.
@@ -344,14 +430,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 		minLen += minFwPacketLen
 	}
 	if len(data) < minLen {
-		return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 
 	// Firewall packets are locally oriented
 	if incoming {
-		//TODO: IPV6-WORK
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
@@ -360,9 +445,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 	} else {
-		//TODO: IPV6-WORK
-		fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
@@ -397,8 +481,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
-		//TODO: maybe after build 64 is out? 06/14/2018 - NB
-		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
 		return false
 	}
 
@@ -445,9 +527,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
 func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
-	//TODO: this should be a signed message so we can trust that we should drop the index
 	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
-	f.outside.WriteTo(b, endpoint)
+	_ = f.outside.WriteTo(b, endpoint)
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", index).
 			WithField("udpAddr", endpoint).
@@ -481,65 +562,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 	// We also delete it from pending hostmap to allow for fast reconnect.
 	f.handshakeManager.DeleteHostInfo(hostinfo)
 }
-
-/*
-func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
-	if ci.eKey != nil {
-		//TODO: log error?
-		return
-	}
-
-	msg, err := proto.Marshal(meta)
-	if err != nil {
-		l.Debugln("failed to encode header")
-	}
-
-	c := ci.messageCounter
-	b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
-	ci.messageCounter++
-
-	msg := ci.eKey.EncryptDanger(b, nil, msg, c)
-	//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
-	f.outside.WriteTo(msg, endpoint)
-}
-*/
-
-func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
-	pk := h.PeerStatic()
-
-	if pk == nil {
-		return nil, errors.New("no peer static key was present")
-	}
-
-	if rawCertBytes == nil {
-		return nil, errors.New("provided payload was empty")
-	}
-
-	r := &cert.RawNebulaCertificate{}
-	err := proto.Unmarshal(rawCertBytes, r)
-	if err != nil {
-		return nil, fmt.Errorf("error unmarshaling cert: %s", err)
-	}
-
-	// If the Details are nil, just exit to avoid crashing
-	if r.Details == nil {
-		return nil, fmt.Errorf("certificate did not contain any details")
-	}
-
-	r.Details.PublicKey = pk
-	recombined, err := proto.Marshal(r)
-	if err != nil {
-		return nil, fmt.Errorf("error while recombining certificate: %s", err)
-	}
-
-	c, _ := cert.UnmarshalNebulaCertificate(recombined)
-	isValid, err := c.Verify(time.Now(), caPool)
-	if err != nil {
-		return c, fmt.Errorf("certificate validation failed: %s", err)
-	} else if !isValid {
-		// This case should never happen but here's to defensive programming!
-		return c, errors.New("certificate validation failed but did not return an error")
-	}
-
-	return c, nil
-}

+ 525 - 17
outside_test.go

@@ -1,10 +1,15 @@
 package nebula
 
 import (
+	"bytes"
+	"encoding/binary"
 	"net"
 	"net/netip"
 	"testing"
 
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers"
+
 	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
@@ -13,9 +18,15 @@ import (
 func Test_newPacket(t *testing.T) {
 	p := &firewall.Packet{}
 
-	// length fail
-	err := newPacket([]byte{0, 1}, true, p)
-	assert.EqualError(t, err, "packet is less than 20 bytes")
+	// length fails
+	err := newPacket([]byte{}, true, p)
+	assert.ErrorIs(t, err, ErrPacketTooShort)
+
+	err = newPacket([]byte{0x40}, true, p)
+	assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
+
+	err = newPacket([]byte{0x60}, true, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 	// length fail with ip options
 	h := ipv4.Header{
@@ -28,16 +39,15 @@ func Test_newPacket(t *testing.T) {
 
 	b, _ := h.Marshal()
 	err = newPacket(b, true, p)
-
-	assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 	// not an ipv4 packet
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet is not ipv4, type: 0")
+	assert.ErrorIs(t, err, ErrUnknownIPVersion)
 
 	// invalid ihl
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet had an invalid header length: 8")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 	// account for variable ip header length - incoming
 	h = ipv4.Header{
@@ -54,11 +64,12 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, true, p)
 
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemotePort, uint16(3))
-	assert.Equal(t, p.LocalPort, uint16(4))
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
+	assert.Equal(t, uint16(3), p.RemotePort)
+	assert.Equal(t, uint16(4), p.LocalPort)
+	assert.False(t, p.Fragment)
 
 	// account for variable ip header length - outgoing
 	h = ipv4.Header{
@@ -75,9 +86,506 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, false, p)
 
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemotePort, uint16(6))
-	assert.Equal(t, p.LocalPort, uint16(5))
+	assert.Equal(t, uint8(2), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
+	assert.Equal(t, uint16(6), p.RemotePort)
+	assert.Equal(t, uint16(5), p.LocalPort)
+	assert.False(t, p.Fragment)
+}
+
+func Test_newPacket_v6(t *testing.T) {
+	p := &firewall.Packet{}
+
+	// invalid ipv6
+	ip := layers.IPv6{
+		Version:  6,
+		HopLimit: 128,
+		SrcIP:    net.IPv6linklocalallrouters,
+		DstIP:    net.IPv6linklocalallnodes,
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opt := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       false,
+	}
+	err := gopacket.SerializeLayers(buffer, opt, &ip)
+	assert.NoError(t, err)
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good ICMP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolICMPv6,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	icmp := layers.ICMPv6{}
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
+	if err != nil {
+		panic(err)
+	}
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good ESP packet
+	b := buffer.Bytes()
+	b[6] = byte(layers.IPProtocolESP)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good None packet
+	b = buffer.Bytes()
+	b[6] = byte(layers.IPProtocolNoNextHeader)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// An unknown protocol packet
+	b = buffer.Bytes()
+	b[6] = 255 // 255 is a reserved protocol number
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good UDP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: firewall.ProtoUDP,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+	err = udp.SetNetworkLayerForChecksum(&ip)
+	assert.NoError(t, err)
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
+	if err != nil {
+		panic(err)
+	}
+	b = buffer.Bytes()
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short UDP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good TCP packet
+	b[6] = byte(layers.IPProtocolTCP)
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short TCP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good UDP packet with an AH header
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolAH,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	ah := layers.IPSecAH{
+		AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
+	}
+	ah.NextHeader = layers.IPProtocolUDP
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opt)
+	if err != nil {
+		panic(err)
+	}
+
+	b = buffer.Bytes()
+	ahb := serializeAH(&ah)
+	b = append(b, ahb...)
+	b = append(b, udpHeader...)
+
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Invalid AH header
+	b = buffer.Bytes()
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+}
+
+func Test_newPacket_ipv6Fragment(t *testing.T) {
+	p := &firewall.Packet{}
+
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	// First fragment
+	fragHeader1 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: true,
+		FixLengths:       true,
+	}
+
+	err := ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader1...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test first fragment incoming
+	err = newPacket(firstFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Test first fragment outgoing
+	err = newPacket(firstFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Second fragment
+	fragHeader2 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0xb9,                        // Fragment Offset high byte (185)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader2...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test second fragment incoming
+	err = newPacket(secondFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.True(t, p.Fragment)
+
+	// Test second fragment outgoing
+	err = newPacket(secondFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.True(t, p.Fragment)
+
+	// Too short of a fragment packet
+	err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+}
+
+func BenchmarkParseV6(b *testing.B) {
+	// Regular UDP packet
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolUDP,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := &layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       true,
+	}
+
+	err := gopacket.SerializeLayers(buffer, opts, ip, udp)
+	if err != nil {
+		b.Fatal(err)
+	}
+	normalPacket := buffer.Bytes()
+
+	// First Fragment packet
+	ipFrag := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	fragHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x7b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Second Fragment packet
+	fragHeader[2] = 0xb9 // offset 185
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	fp := &firewall.Packet{}
+
+	b.Run("Normal", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(normalPacket, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("FirstFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(firstFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("SecondFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(secondFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	// Evil packet
+	evilPacket := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6HopByHop,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	hopHeader := []byte{
+		uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
+		0x00,                                 // Length
+		0x00, 0x00,                           // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	lastHopHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Length
+		0x00, 0x00,                  // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	buffer.Clear()
+	err = evilPacket.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	evilBytes := buffer.Bytes()
+	for i := 0; i < 200; i++ {
+		evilBytes = append(evilBytes, hopHeader...)
+	}
+	evilBytes = append(evilBytes, lastHopHeader...)
+	evilBytes = append(evilBytes, udpHeader...)
+	evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	b.Run("200 HopByHop headers", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(evilBytes, false, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+}
+
+// Ensure authentication data is a multiple of 8 bytes by padding if necessary
+func padAuthData(authData []byte) []byte {
+	// Length of Authentication Data must be a multiple of 8 bytes
+	paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
+	if paddingLength > 0 {
+		authData = append(authData, make([]byte, paddingLength)...)
+	}
+	return authData
+}
+
+// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
+func serializeAH(ah *layers.IPSecAH) []byte {
+	buf := new(bytes.Buffer)
+
+	// Ensure Authentication Data is a multiple of 8 bytes
+	ah.AuthenticationData = padAuthData(ah.AuthenticationData)
+	// Calculate Payload Length (in 32-bit words, minus 2)
+	payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
+
+	// Serialize fields
+	if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
+		panic(err)
+	}
+	if len(ah.AuthenticationData) > 0 {
+		if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
+			panic(err)
+		}
+	}
+
+	return buf.Bytes()
 }

+ 1 - 1
overlay/device.go

@@ -8,7 +8,7 @@ import (
 type Device interface {
 	io.ReadWriteCloser
 	Activate() error
-	Cidr() netip.Prefix
+	Networks() []netip.Prefix
 	Name() string
 	RouteFor(netip.Addr) netip.Addr
 	NewMultiQueueReader() (io.ReadWriteCloser, error)

+ 22 - 12
overlay/route.go

@@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
 	return routeTree, nil
 }
 
-func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 
 	r := c.Get("tun.routes")
@@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 		}
 
-		if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
+		found := false
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
+				found = true
+				break
+			}
+		}
+
+		if !found {
 			return nil, fmt.Errorf(
-				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
+				"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
 				i+1,
 				r.Cidr.String(),
-				network.String(),
+				networks,
 			)
 		}
 
@@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	return routes, nil
 }
 
-func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 
 	r := c.Get("tun.unsafe_routes")
@@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 		}
 
-		if network.Contains(r.Cidr.Addr()) {
-			return nil, fmt.Errorf(
-				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
-				i+1,
-				r.Cidr.String(),
-				network.String(),
-			)
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) {
+				return nil, fmt.Errorf(
+					"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
+					i+1,
+					r.Cidr.String(),
+					network.String(),
+				)
+			}
 		}
 
 		routes[i] = r

+ 39 - 33
overlay/route_test.go

@@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
 	assert.NoError(t, err)
 
 	// test no routes config
-	routes, err := parseRoutes(c, n)
+	routes, err := parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.routes is not an array")
 
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
 
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
+
+	// Not in multiple ranges
+	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
+	routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
+	assert.Nil(t, routes)
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
 
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 2)
 
@@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	assert.NoError(t, err)
 
 	// test no routes config
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 	// no via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
@@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		127, false, nil, 1.0, []string{"1", "2"},
 	} {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
-		routes, err = parseUnsafeRoutes(c, n)
+		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 	}
 
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Equal(t, 0, routes[0].MTU)
 
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 	// bad install
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 
@@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 4)
 
@@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 	}}
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.NoError(t, err)
 	assert.Len(t, routes, 2)
 	routeTree, err := makeRouteTree(l, routes, true)

+ 9 - 9
overlay/tun.go

@@ -11,36 +11,36 @@ import (
 const DefaultMTU = 1300
 
 // TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
 	switch {
 	case c.GetBool("tun.disabled", false):
-		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
+		tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 
 	default:
-		return newTun(c, l, tunCidr, routines > 1)
+		return newTun(c, l, vpnNetworks, routines > 1)
 	}
 }
 
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
-		return newTunFromFd(c, l, *fd, tunCidr)
+	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+		return newTunFromFd(c, l, *fd, vpnNetworks)
 	}
 }
 
-func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 		return false, nil, nil
 	}
 
-	routes, err := parseRoutes(c, cidr)
+	routes, err := parseRoutes(c, vpnNetworks)
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 	}
 
-	unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
+	unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}

+ 11 - 11
overlay/tun_android.go

@@ -18,14 +18,14 @@ import (
 
 type tun struct {
 	io.ReadWriteCloser
-	fd        int
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	fd          int
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	t := &tun{
 		ReadWriteCloser: file,
 		fd:              deviceFd,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		l:               l,
 	}
 
@@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 }
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 
@@ -66,7 +66,7 @@ func (t tun) Activate() error {
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 208 - 215
overlay/tun_darwin.go

@@ -24,56 +24,62 @@ import (
 
 type tun struct {
 	io.ReadWriteCloser
-	Device     string
-	cidr       netip.Prefix
-	DefaultMTU int
-	Routes     atomic.Pointer[[]Route]
-	routeTree  atomic.Pointer[bart.Table[netip.Addr]]
-	linkAddr   *netroute.LinkAddr
-	l          *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	DefaultMTU  int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	linkAddr    *netroute.LinkAddr
+	l           *logrus.Logger
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	out []byte
 }
 
-type sockaddrCtl struct {
-	scLen      uint8
-	scFamily   uint8
-	ssSysaddr  uint16
-	scID       uint32
-	scUnit     uint32
-	scReserved [5]uint32
-}
-
 type ifReq struct {
-	Name  [16]byte
+	Name  [unix.IFNAMSIZ]byte
 	Flags uint16
 	pad   [8]byte
 }
 
-var sockaddrCtlSize uintptr = 32
-
 const (
-	_SYSPROTO_CONTROL = 2              //define SYSPROTO_CONTROL 2 /* kernel control protocol */
-	_AF_SYS_CONTROL   = 2              //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
-	_PF_SYSTEM        = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
-	_CTLIOCGINFO      = 3227799043     //#define CTLIOCGINFO     _IOWR('N', 3, struct ctl_info)
-	utunControlName   = "com.apple.net.utun_control"
+	_SIOCAIFADDR_IN6 = 2155899162
+	_UTUN_OPT_IFNAME = 2
+	_IN6_IFF_NODAD   = 0x0020
+	_IN6_IFF_SECURED = 0x0400
+	utunControlName  = "com.apple.net.utun_control"
 )
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 	Name [16]byte
 	MTU  int32
 	pad  [8]byte
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+type addrLifetime struct {
+	Expire    float64
+	Preferred float64
+	Vltime    uint32
+	Pltime    uint32
+}
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	name := c.GetString("tun.dev", "")
 	ifIndex := -1
 	if name != "" && name != "utun" {
@@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 		}
 	}
 
-	fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
+	fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
 	if err != nil {
 		return nil, fmt.Errorf("system socket: %v", err)
 	}
 
-	var ctlInfo = &struct {
-		ctlID   uint32
-		ctlName [96]byte
-	}{}
+	var ctlInfo = &unix.CtlInfo{}
+	copy(ctlInfo.Name[:], utunControlName)
 
-	copy(ctlInfo.ctlName[:], utunControlName)
-
-	err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
+	err = unix.IoctlCtlInfo(fd, ctlInfo)
 	if err != nil {
 		return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
 	}
 
-	sc := sockaddrCtl{
-		scLen:     uint8(sockaddrCtlSize),
-		scFamily:  unix.AF_SYSTEM,
-		ssSysaddr: _AF_SYS_CONTROL,
-		scID:      ctlInfo.ctlID,
-		scUnit:    uint32(ifIndex) + 1,
-	}
-
-	_, _, errno := unix.RawSyscall(
-		unix.SYS_CONNECT,
-		uintptr(fd),
-		uintptr(unsafe.Pointer(&sc)),
-		sockaddrCtlSize,
-	)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
+	err = unix.Connect(fd, &unix.SockaddrCtl{
+		ID:   ctlInfo.Id,
+		Unit: uint32(ifIndex) + 1,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("SYS_CONNECT: %v", err)
 	}
 
-	var ifName struct {
-		name [16]byte
-	}
-	ifNameSize := uintptr(len(ifName.name))
-	_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
-		2, // SYSPROTO_CONTROL
-		2, // UTUN_OPT_IFNAME
-		uintptr(unsafe.Pointer(&ifName)),
-		uintptr(unsafe.Pointer(&ifNameSize)), 0)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
+	name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
+	if err != nil {
+		return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
 	}
-	name = string(ifName.name[:ifNameSize-1])
 
-	err = syscall.SetNonblock(fd, true)
+	err = unix.SetNonblock(fd, true)
 	if err != nil {
 		return nil, fmt.Errorf("SetNonblock: %v", err)
 	}
 
-	file := os.NewFile(uintptr(fd), "")
-
 	t := &tun{
-		ReadWriteCloser: file,
+		ReadWriteCloser: os.NewFile(uintptr(fd), ""),
 		Device:          name,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		DefaultMTU:      c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 
@@ -186,16 +167,6 @@ func (t *tun) Close() error {
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 
-	var addr, mask [4]byte
-
-	if !t.cidr.Addr().Is4() {
-		//TODO: IPV6-WORK
-		panic("need ipv6")
-	}
-
-	addr = t.cidr.Addr().As4()
-	copy(mask[:], prefixToMask(t.cidr))
-
 	s, err := unix.Socket(
 		unix.AF_INET,
 		unix.SOCK_DGRAM,
@@ -208,66 +179,18 @@ func (t *tun) Activate() error {
 
 	fd := uintptr(s)
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
-	// Set the device name
-	ifrf := ifReq{Name: devName}
-	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to set tun device name: %s", err)
-	}
-
 	// Set the MTU on the device
 	ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 		return fmt.Errorf("failed to set tun mtu: %v", err)
 	}
 
-	/*
-		// Set the transmit queue length
-		ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
-		if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
-			// If we can't set the queue length nebula will still work but it may lead to packet loss
-			l.WithError(err).Error("Failed to set tun tx queue length")
-		}
-	*/
-
-	// Bring up the interface
-	ifrf.Flags = ifrf.Flags | unix.IFF_UP
-	if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to bring the tun device up: %s", err)
-	}
-
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	// Get the device flags
+	ifrf := ifReq{Name: devName}
+	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+		return fmt.Errorf("failed to get tun flags: %s", err)
 	}
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
 
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	linkAddr, err := getLinkAddr(t.Device)
 	if err != nil {
 		return err
@@ -277,14 +200,18 @@ func (t *tun) Activate() error {
 	}
 	t.linkAddr = linkAddr
 
-	copy(routeAddr.IP[:], addr[:])
-	copy(maskAddr.IP[:], mask[:])
-	err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
-	if err != nil {
-		if errors.Is(err, unix.EEXIST) {
-			err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
+	for _, network := range t.vpnNetworks {
+		if network.Addr().Is4() {
+			err = t.activate4(network)
+			if err != nil {
+				return err
+			}
+		} else {
+			err = t.activate6(network)
+			if err != nil {
+				return err
+			}
 		}
-		return err
 	}
 
 	// Run the interface
@@ -297,8 +224,89 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) activate4(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias4{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		DstAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		MaskAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(network).As4(),
+		},
+	}
+
+	if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun v4 address: %s", err)
+	}
+
+	err = addRoute(network, t.linkAddr)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (t *tun) activate6(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET6,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias6{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   network.Addr().As16(),
+		},
+		PrefixMask: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(network).As16(),
+		},
+		Lifetime: addrLifetime{
+			// never expires
+			Vltime: 0xffffffff,
+			Pltime: 0xffffffff,
+		},
+		//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
+		Flags: _IN6_IFF_NODAD,
+	}
+
+	if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun address: %s", err)
+	}
+
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 }
 
 // Get the LinkAddr for the interface of the given name
-// TODO: Is there an easier way to fetch this when we create the interface?
+// Is there an easier way to fetch this when we create the interface?
 // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
 func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 	rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
@@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 }
 
 func (t *tun) addRoutes(logErrors bool) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		if !r.Cidr.Addr().Is4() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		//TODO: we could avoid the copy
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := addRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 			if errors.Is(err, unix.EEXIST) {
 				t.l.WithField("route", r.Cidr).
@@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
 }
 
 func (t *tun) removeRoutes(routes []Route) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
-
 	for _, r := range routes {
 		if !r.Install {
 			continue
 		}
 
-		if r.Cidr.Addr().Is6() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := delRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
@@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_ADD,
 		Flags:   unix.RTF_UP,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
+
 	_, err = unix.Write(sock, data[:])
 	if err != nil {
 		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 	return nil
 }
 
-func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_DELETE,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
@@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 }
 
 func (t *tun) Read(to []byte) (int, error) {
-
 	buf := make([]byte, len(to)+4)
 
 	n, err := t.ReadWriteCloser.Read(buf)
@@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
 	return n - 4, err
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {
@@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
 
-func prefixToMask(prefix netip.Prefix) []byte {
+func prefixToMask(prefix netip.Prefix) netip.Addr {
 	pLen := 128
 	if prefix.Addr().Is4() {
 		pLen = 32
 	}
-	return net.CIDRMask(prefix.Bits(), pLen)
+
+	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+	return addr
 }

+ 8 - 8
overlay/tun_disabled.go

@@ -12,8 +12,8 @@ import (
 )
 
 type disabledTun struct {
-	read chan []byte
-	cidr netip.Prefix
+	read        chan []byte
+	vpnNetworks []netip.Prefix
 
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
@@ -21,11 +21,11 @@ type disabledTun struct {
 	l  *logrus.Logger
 }
 
-func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
-		cidr: cidr,
-		read: make(chan []byte, queueLen),
-		l:    l,
+		vpnNetworks: vpnNetworks,
+		read:        make(chan []byte, queueLen),
+		l:           l,
 	}
 
 	if metricsEnabled {
@@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
 	return netip.Addr{}
 }
 
-func (t *disabledTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *disabledTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (*disabledTun) Name() string {

+ 25 - 15
overlay/tun_freebsd.go

@@ -46,12 +46,12 @@ type ifreqDestroy struct {
 }
 
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 	io.ReadWriteCloser
 }
@@ -78,11 +78,11 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
 	var file *os.File
 	var err error
@@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 		ReadWriteCloser: file,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 }
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -195,8 +195,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 10 - 10
overlay/tun_ios.go

@@ -21,20 +21,20 @@ import (
 
 type tun struct {
 	io.ReadWriteCloser
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	t := &tun{
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		ReadWriteCloser: &tunReadCloser{f: file},
 		l:               l,
 	}
@@ -59,7 +59,7 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
 	return tr.f.Close()
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 120 - 78
overlay/tun_linux.go

@@ -11,6 +11,7 @@ import (
 	"os"
 	"strings"
 	"sync/atomic"
+	"time"
 	"unsafe"
 
 	"github.com/gaissmai/bart"
@@ -25,7 +26,7 @@ type tun struct {
 	io.ReadWriteCloser
 	fd          int
 	Device      string
-	cidr        netip.Prefix
+	vpnNetworks []netip.Prefix
 	MaxMTU      int
 	DefaultMTU  int
 	TXQueueLen  int
@@ -40,18 +41,16 @@ type tun struct {
 	l *logrus.Logger
 }
 
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
+
 type ifReq struct {
 	Name  [16]byte
 	Flags uint16
 	pad   [8]byte
 }
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 	Name [16]byte
 	MTU  int32
@@ -64,10 +63,10 @@ type ifreqQLEN struct {
 	pad   [8]byte
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 		return nil, err
 	}
@@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	name := strings.Trim(string(req.Name[:]), "\x00")
 
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 		return nil, err
 	}
@@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	return t, nil
 }
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
 		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
 		l:               l,
@@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
 		}
 
 		if oldDefaultMTU != newDefaultMTU {
-			err := t.setDefaultRoute()
-			if err != nil {
-				t.l.Warn(err)
-			} else {
-				t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+			for i := range t.vpnNetworks {
+				err := t.setDefaultRoute(t.vpnNetworks[i])
+				if err != nil {
+					t.l.Warn(err)
+				} else {
+					t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+				}
 			}
 		}
 
@@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 
 func (t *tun) Write(b []byte) (int, error) {
 	var nn int
-	max := len(b)
+	maximum := len(b)
 
 	for {
-		n, err := unix.Write(t.fd, b[nn:max])
+		n, err := unix.Write(t.fd, b[nn:maximum])
 		if n > 0 {
 			nn += n
 		}
@@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
+func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
+	for i := range al {
+		if al[i].Equal(x) {
+			return true
+		}
+	}
+	return false
+}
+
+// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
+func (t *tun) addIPs(link netlink.Link) error {
+	newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
+	for i := range t.vpnNetworks {
+		newAddrs[i] = &netlink.Addr{
+			IPNet: &net.IPNet{
+				IP:   t.vpnNetworks[i].Addr().AsSlice(),
+				Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
+			},
+			Label: t.vpnNetworks[i].Addr().Zone(),
+		}
+	}
+
+	//add all new addresses
+	for i := range newAddrs {
+		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
+		//AddrReplace still adds new IPs, but if their properties change it will change them as well
+		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
+			return err
+		}
+	}
+
+	//iterate over remainder, remove whoever shouldn't be there
+	al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
+	if err != nil {
+		return fmt.Errorf("failed to get tun address list: %s", err)
+	}
+
+	for i := range al {
+		if hasNetlinkAddr(newAddrs, al[i]) {
+			continue
+		}
+		err = netlink.AddrDel(link, &al[i])
+		if err != nil {
+			t.l.WithError(err).Error("failed to remove address from tun address list")
+		} else {
+			t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
+		}
+	}
+
+	return nil
+}
+
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 
@@ -272,15 +325,8 @@ func (t *tun) Activate() error {
 		t.watchRoutes()
 	}
 
-	var addr, mask [4]byte
-
-	//TODO: IPV6-WORK
-	addr = t.cidr.Addr().As4()
-	tmask := net.CIDRMask(t.cidr.Bits(), 32)
-	copy(mask[:], tmask)
-
 	s, err := unix.Socket(
-		unix.AF_INET,
+		unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
 		unix.SOCK_DGRAM,
 		unix.IPPROTO_IP,
 	)
@@ -289,31 +335,19 @@ func (t *tun) Activate() error {
 	}
 	t.ioctlFd = uintptr(s)
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
 	// Set the device name
 	ifrf := ifReq{Name: devName}
 	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to set tun device name: %s", err)
 	}
 
+	link, err := netlink.LinkByName(t.Device)
+	if err != nil {
+		return fmt.Errorf("failed to get tun device link: %s", err)
+	}
+
+	t.deviceIndex = link.Attrs().Index
+
 	// Setup our default MTU
 	t.setMTU()
 
@@ -324,20 +358,21 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 
+	if err = t.addIPs(link); err != nil {
+		return err
+	}
+
 	// Bring up the interface
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP
 	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to bring the tun device up: %s", err)
 	}
 
-	link, err := netlink.LinkByName(t.Device)
-	if err != nil {
-		return fmt.Errorf("failed to get tun device link: %s", err)
-	}
-	t.deviceIndex = link.Attrs().Index
-
-	if err = t.setDefaultRoute(); err != nil {
-		return err
+	//set route MTU
+	for i := range t.vpnNetworks {
+		if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
+			return fmt.Errorf("failed to set default route MTU: %w", err)
+		}
 	}
 
 	// Set the routes
@@ -363,12 +398,10 @@ func (t *tun) setMTU() {
 	}
 }
 
-func (t *tun) setDefaultRoute() error {
-	// Default route
-
+func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
 	dr := &net.IPNet{
-		IP:   t.cidr.Masked().Addr().AsSlice(),
-		Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+		IP:   cidr.Masked().Addr().AsSlice(),
+		Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
 	}
 
 	nr := netlink.Route{
@@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error {
 		MTU:       t.DefaultMTU,
 		AdvMSS:    t.advMSS(Route{}),
 		Scope:     unix.RT_SCOPE_LINK,
-		Src:       net.IP(t.cidr.Addr().AsSlice()),
+		Src:       net.IP(cidr.Addr().AsSlice()),
 		Protocol:  unix.RTPROT_KERNEL,
 		Table:     unix.RT_TABLE_MAIN,
 		Type:      unix.RTN_UNICAST,
 	}
 	err := netlink.RouteReplace(&nr)
 	if err != nil {
-		return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
+		t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
+		//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
+		for i := 0; i < 2; i++ {
+			time.Sleep(100 * time.Millisecond)
+			err = netlink.RouteReplace(&nr)
+			if err == nil {
+				break
+			} else {
+				t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
+			}
+		}
+		if err != nil {
+			return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
+		}
 	}
 
 	return nil
@@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) {
 	}
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
 func (t *tun) Name() string {
 	return t.Device
 }
@@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	//TODO: IPV6-WORK what if not ok?
 	gwAddr, ok := netip.AddrFromSlice(r.Gw)
 	if !ok {
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
@@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	}
 
 	gwAddr = gwAddr.Unmap()
-	if !t.cidr.Contains(gwAddr) {
-		// Gateway isn't in our overlay network, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
-		return
+	withinNetworks := false
+	for i := range t.vpnNetworks {
+		if t.vpnNetworks[i].Contains(gwAddr) {
+			withinNetworks = true
+			break
+		}
 	}
-
-	if x := r.Dst.IP.To4(); x == nil {
-		// Nebula only handles ipv4 on the overlay currently
-		t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
+	if !withinNetworks {
+		// Gateway isn't in our overlay network, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
 		return
 	}
 
@@ -563,11 +605,11 @@ func (t *tun) Close() error {
 	}
 
 	if t.ReadWriteCloser != nil {
-		t.ReadWriteCloser.Close()
+		_ = t.ReadWriteCloser.Close()
 	}
 
 	if t.ioctlFd > 0 {
-		os.NewFile(t.ioctlFd, "ioctlFd").Close()
+		_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
 	}
 
 	return nil

+ 28 - 17
overlay/tun_netbsd.go

@@ -27,12 +27,12 @@ type ifreqDestroy struct {
 }
 
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 	io.ReadWriteCloser
 }
@@ -58,13 +58,13 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
 	var file *os.File
 	var err error
@@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 		ReadWriteCloser: file,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 }
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -130,8 +130,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {
@@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

+ 29 - 19
overlay/tun_openbsd.go

@@ -21,12 +21,12 @@ import (
 )
 
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 	io.ReadWriteCloser
 
@@ -42,13 +42,13 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 }
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 		ReadWriteCloser: file,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -138,7 +138,7 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -148,6 +148,16 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
@@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 		if !r.Install {
 			continue
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 17 - 17
overlay/tun_tester.go

@@ -16,19 +16,19 @@ import (
 )
 
 type TestTun struct {
-	Device    string
-	cidr      netip.Prefix
-	Routes    []Route
-	routeTree *bart.Table[netip.Addr]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	Routes      []Route
+	routeTree   *bart.Table[netip.Addr]
+	l           *logrus.Logger
 
 	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
-	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
+	_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
 	if err != nil {
 		return nil, err
 	}
@@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
 	}
 
 	return &TestTun{
-		Device:    c.GetString("tun.dev", ""),
-		cidr:      cidr,
-		Routes:    routes,
-		routeTree: routeTree,
-		l:         l,
-		rxPackets: make(chan []byte, 10),
-		TxPackets: make(chan []byte, 10),
+		Device:      c.GetString("tun.dev", ""),
+		vpnNetworks: vpnNetworks,
+		Routes:      routes,
+		routeTree:   routeTree,
+		l:           l,
+		rxPackets:   make(chan []byte, 10),
+		TxPackets:   make(chan []byte, 10),
 	}, nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 
@@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
 	return nil
 }
 
-func (t *TestTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *TestTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *TestTun) Name() string {

Some files were not shown because too many files changed in this diff