Browse Source

V2 certificate format (#1216)

Co-authored-by: Nate Brown <[email protected]>
Co-authored-by: Jack Doan <[email protected]>
Co-authored-by: brad-defined <[email protected]>
Co-authored-by: Jack Doan <[email protected]>
Nate Brown 4 months ago
parent
commit
d97ed57a19
100 changed files with 7728 additions and 4278 deletions
  1. 2 1
      .gitignore
  2. 1 1
      Makefile
  3. 25 12
      allow_list.go
  4. 34 25
      calculated_remote.go
  5. 61 5
      calculated_remote_test.go
  6. 15 4
      cert/README.md
  7. 52 0
      cert/asn1.go
  8. 1 1
      cert/ca_pool.go
  9. 481 31
      cert/ca_pool_test.go
  10. 50 14
      cert/cert.go
  11. 0 695
      cert/cert_test.go
  12. 225 232
      cert/cert_v1.go
  13. 218 0
      cert/cert_v1_test.go
  14. 37 0
      cert/cert_v2.asn1
  15. 730 0
      cert/cert_v2.go
  16. 267 0
      cert/cert_v2_test.go
  17. 34 12
      cert/errors.go
  18. 141 0
      cert/helper_test.go
  19. 13 7
      cert/pem.go
  20. 105 14
      cert/sign.go
  21. 90 0
      cert/sign_test.go
  22. 51 11
      cert_test/cert.go
  23. 46 22
      cmd/nebula-cert/ca.go
  24. 20 16
      cmd/nebula-cert/ca_test.go
  25. 0 2
      cmd/nebula-cert/keygen_test.go
  26. 0 2
      cmd/nebula-cert/main_test.go
  27. 11 6
      cmd/nebula-cert/print.go
  28. 72 3
      cmd/nebula-cert/print_test.go
  29. 166 66
      cmd/nebula-cert/sign.go
  30. 47 36
      cmd/nebula-cert/sign_test.go
  31. 25 14
      cmd/nebula-cert/verify.go
  32. 4 2
      cmd/nebula-cert/verify_test.go
  33. 0 3
      config/config_test.go
  34. 57 32
      connection_manager.go
  35. 26 29
      connection_manager_test.go
  36. 26 21
      connection_state.go
  37. 26 25
      control.go
  38. 17 17
      control_test.go
  39. 43 22
      control_tester.go
  40. 74 36
      dns_server.go
  41. 20 5
      dns_server_test.go
  42. 304 186
      e2e/handshakes_test.go
  43. 74 37
      e2e/helpers_test.go
  44. 4 3
      e2e/router/hostmap.go
  45. 44 22
      e2e/router/router.go
  46. 12 5
      examples/config.yml
  47. 65 56
      firewall.go
  48. 11 10
      firewall/packet.go
  49. 41 38
      firewall_test.go
  50. 0 1
      go.mod
  51. 0 2
      go.sum
  52. 196 76
      handshake_ix.go
  53. 180 123
      handshake_manager.go
  54. 19 11
      handshake_manager_test.go
  55. 169 94
      hostmap.go
  56. 23 33
      hostmap_test.go
  57. 2 2
      hostmap_tester.go
  58. 31 27
      inside.go
  59. 83 60
      interface.go
  60. 0 2
      iputil/packet.go
  61. 364 240
      lighthouse.go
  62. 173 146
      lighthouse_test.go
  63. 5 24
      main.go
  64. 0 2
      message_metrics.go
  65. 467 171
      nebula.pb.go
  66. 23 9
      nebula.proto
  67. 154 118
      outside.go
  68. 525 17
      outside_test.go
  69. 1 1
      overlay/device.go
  70. 22 12
      overlay/route.go
  71. 39 33
      overlay/route_test.go
  72. 9 9
      overlay/tun.go
  73. 11 11
      overlay/tun_android.go
  74. 208 215
      overlay/tun_darwin.go
  75. 8 8
      overlay/tun_disabled.go
  76. 25 15
      overlay/tun_freebsd.go
  77. 10 10
      overlay/tun_ios.go
  78. 120 78
      overlay/tun_linux.go
  79. 28 17
      overlay/tun_netbsd.go
  80. 29 19
      overlay/tun_openbsd.go
  81. 17 17
      overlay/tun_tester.go
  82. 0 208
      overlay/tun_water_windows.go
  83. 240 13
      overlay/tun_windows.go
  84. 0 252
      overlay/tun_wintun_windows.go
  85. 6 6
      overlay/user.go
  86. 328 68
      pki.go
  87. 166 130
      relay_manager.go
  88. 51 35
      remote_list.go
  89. 35 20
      remote_list_test.go
  90. 2 2
      service/service.go
  91. 3 3
      service/service_test.go
  92. 73 88
      ssh.go
  93. 1 7
      sshd/command.go
  94. 1 3
      sshd/server.go
  95. 1 12
      sshd/session.go
  96. 2 2
      test/tun.go
  97. 4 4
      timeout_test.go
  98. 3 12
      udp/conn.go
  99. 0 10
      udp/temp.go
  100. 3 19
      udp/udp_generic.go

+ 2 - 1
.gitignore

@@ -5,7 +5,8 @@
 /nebula-darwin
 /nebula.exe
 /nebula-cert.exe
-/coverage.out
+**/coverage.out
+**/cover.out
 /cpu.pprof
 /build
 /*.tar.gz

+ 1 - 1
Makefile

@@ -196,7 +196,7 @@ bench-cpu-long:
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 
-proto: nebula.pb.go cert/cert.pb.go
+proto: nebula.pb.go cert/cert_v1.pb.go
 
 nebula.pb.go: nebula.proto .FORCE
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster

+ 25 - 12
allow_list.go

@@ -128,7 +128,6 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
 
-		// TODO: should we error on duplicate CIDRs in the config?
 		tree.Insert(ipNet, value)
 
 		maskBits := ipNet.Bits()
@@ -251,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
 	return remoteAllowRanges, nil
 }
 
-func (al *AllowList) Allow(ip netip.Addr) bool {
+func (al *AllowList) Allow(addr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
 
-	result, _ := al.cidrTree.Lookup(ip)
+	result, _ := al.cidrTree.Lookup(addr)
 	return result
 }
 
-func (al *LocalAllowList) Allow(ip netip.Addr) bool {
+func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
 func (al *LocalAllowList) AllowName(name string) bool {
@@ -282,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 }
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
+func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(vpnAddr)
 }
 
-func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
-	if !al.getInsideAllowList(vpnIp).Allow(ip) {
+func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
+	if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
 		return false
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
-func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
+func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
+	if !al.AllowList.Allow(udpAddr) {
+		return false
+	}
+
+	for _, vpnAddr := range vpnAddrs {
+		if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
+			return false
+		}
+	}
+
+	return true
+}
+
+func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
-		inside, ok := al.insideAllowLists.Lookup(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnAddr)
 		if ok {
 			return inside
 		}

+ 34 - 25
calculated_remote.go

@@ -21,7 +21,11 @@ type calculatedRemote struct {
 	port  uint32
 }
 
-func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
+		return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
+	}
+
 	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
@@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 
-func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
-	// Combine the masked bytes of the "mask" IP with the unmasked bytes
-	// of the overlay IP
-	if c.ipNet.Addr().Is4() {
-		return c.apply4(ip)
-	}
-	return c.apply6(ip)
-}
-
-func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK this can be less crappy
+func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
+	// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
 	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
 	mask := binary.BigEndian.Uint32(maskb[:])
 
 	b := c.mask.Addr().As4()
-	maskIp := binary.BigEndian.Uint32(b[:])
+	maskAddr := binary.BigEndian.Uint32(b[:])
 
-	b = ip.As4()
-	intIp := binary.BigEndian.Uint32(b[:])
+	b = addr.As4()
+	intAddr := binary.BigEndian.Uint32(b[:])
 
-	return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+	return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
 }
 
-func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK
-	panic("Can not calculate ipv6 remote addresses")
+func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
+	mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	maskAddr := c.mask.Addr().As16()
+	calcAddr := addr.As16()
+
+	ap := V6AddrPort{Port: c.port}
+
+	maskb := binary.BigEndian.Uint64(mask[:8])
+	maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
+	calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
+	ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	maskb = binary.BigEndian.Uint64(mask[8:])
+	maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
+	calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
+	ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	return &ap
 }
 
 func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 
-		//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
-		entry, err := newCalculatedRemotesListFromConfig(rawValue)
+		entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
@@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 	return calculatedRemotes, nil
 }
 
-func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
+func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
 	rawList, ok := raw.([]any)
 	if !ok {
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 
 	var l []*calculatedRemote
 	for _, e := range rawList {
-		c, err := newCalculatedRemotesEntryFromConfig(e)
+		c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
 		if err != nil {
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 		}
@@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 	return l, nil
 }
 
-func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
+func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
 	rawMap, ok := raw.(map[any]any)
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
@@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 
-	return newCalculatedRemote(maskCidr, port)
+	return newCalculatedRemote(cidr, maskCidr, port)
 }

+ 61 - 5
calculated_remote_test.go

@@ -9,10 +9,9 @@ import (
 )
 
 func TestCalculatedRemoteApply(t *testing.T) {
-	ipNet, err := netip.ParsePrefix("192.168.1.0/24")
-	require.NoError(t, err)
-
-	c, err := newCalculatedRemote(ipNet, 4242)
+	// Test v4 addresses
+	ipNet := netip.MustParsePrefix("192.168.1.0/24")
+	c, err := newCalculatedRemote(ipNet, ipNet, 4242)
 	require.NoError(t, err)
 
 	input, err := netip.ParseAddr("10.0.10.182")
@@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	expected, err := netip.ParseAddr("192.168.1.182")
 	assert.NoError(t, err)
 
-	assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
+	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
+
+	// Test v6 addresses
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+}
+
+func Test_newCalculatedRemote(t *testing.T) {
+	c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
 }

+ 15 - 4
cert/README.md

@@ -2,14 +2,25 @@
 
 This is a library for interacting with `nebula` style certificates and authorities.
 
-A `protobuf` definition of the certificate format is also included
+There are now 2 versions of `nebula` certificates:
 
-### Compiling the protobuf definition
+## v1
 
-Make sure you have `protoc` installed.
+This version is deprecated.
+
+A `protobuf` definition of the certificate format is included at `cert_v1.proto`
+
+To compile the definition you will need `protoc` installed.
 
 To compile for `go` with the same version of protobuf specified in go.mod:
 
 ```bash
-make
+make proto
 ```
+
+## v2
+
+This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
+future certificate changes better than v1.
+
+`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

+ 52 - 0
cert/asn1.go

@@ -0,0 +1,52 @@
+package cert
+
+import (
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+)
+
+// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
+// https://github.com/golang/go/issues/64811#issuecomment-1944446920
+func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0] > 0
+		return true
+	}
+
+	return false
+}
+
+// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
+// Similar issue as with readOptionalASN1Boolean
+func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0]
+		return true
+	}
+
+	return false
+}

+ 1 - 1
cert/ca_pool.go

@@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
 		return signer, nil
 	}
 
-	return nil, fmt.Errorf("could not find ca for the certificate")
+	return nil, ErrCaNotFound
 }
 
 // GetFingerprints returns an array of trusted CA fingerprints

+ 481 - 31
cert/ca_pool_test.go

@@ -1,7 +1,9 @@
 package cert
 
 import (
+	"net/netip"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -10,15 +12,15 @@ func TestNewCAPoolFromBytes(t *testing.T) {
 	noNewLines := `
 # Current provisional, Remove once everything moves over to the real root.
 -----BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
 -----END NEBULA CERTIFICATE-----
 # root-ca01
 -----BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
 -----END NEBULA CERTIFICATE-----
 `
 
@@ -26,18 +28,18 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
 # Current provisional, Remove once everything moves over to the real root.
 
 -----BEGIN NEBULA CERTIFICATE-----
-CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL
-vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv
-bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
+Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+
+PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf
+2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ==
 -----END NEBULA CERTIFICATE-----
 
 # root-ca01
 
 
 -----BEGIN NEBULA CERTIFICATE-----
-CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG
-BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
-8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF
+CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br
+BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye
+rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA==
 -----END NEBULA CERTIFICATE-----
 
 `
@@ -45,65 +47,513 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf
 	expired := `
 # expired certificate
 -----BEGIN NEBULA CERTIFICATE-----
-CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4
-vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie
-WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs=
+CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA
+7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8
+Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0=
 -----END NEBULA CERTIFICATE-----
 `
 
 	p256 := `
 # p256 certificate
 -----BEGIN NEBULA CERTIFICATE-----
-CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
-6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H
-76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
-IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
+CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp
+k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq
+75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA==
 -----END NEBULA CERTIFICATE-----
 `
 
 	rootCA := certificateV1{
 		details: detailsV1{
-			Name: "nebula root ca",
+			name: "nebula root ca",
 		},
 	}
 
 	rootCA01 := certificateV1{
 		details: detailsV1{
-			Name: "nebula root ca 01",
+			name: "nebula root ca 01",
 		},
 	}
 
 	rootCAP256 := certificateV1{
 		details: detailsV1{
-			Name: "nebula P256 test",
+			name: "nebula P256 test",
 		},
 	}
 
 	p, err := NewCAPoolFromPEM([]byte(noNewLines))
 	assert.Nil(t, err)
-	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
-	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
+	assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
 
 	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
 	assert.Nil(t, err)
-	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
-	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
+	assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
 
 	// expired cert, no valid certs
 	ppp, err := NewCAPoolFromPEM([]byte(expired))
 	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
+	assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
 
 	// expired cert, with valid certs
 	pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
 	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
-	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
-	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
+	assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
+	assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired")
 	assert.Equal(t, len(pppp.CAs), 3)
 
 	ppppp, err := NewCAPoolFromPEM([]byte(p256))
 	assert.Nil(t, err)
-	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name)
+	assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
 	assert.Equal(t, len(ppppp.CAs), 1)
 }
+
+func TestCertificateV1_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV1_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_VerifyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+
+	caPool := NewCAPool()
+	assert.NoError(t, caPool.AddCA(ca))
+
+	f, err := c.Fingerprint()
+	assert.Nil(t, err)
+	caPool.BlocklistFingerprint(f)
+
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.EqualError(t, err, "certificate is in the block list")
+
+	caPool.ResetCertBlocklist()
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
+	assert.EqualError(t, err, "root certificate is expired")
+
+	assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	})
+
+	// Test group assertion
+	ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool = NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
+		NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
+	})
+
+	c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify_IPs(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}
+
+func TestCertificateV2_Verify_Subnets(t *testing.T) {
+	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
+	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+
+	caPem, err := ca.MarshalPEM()
+	assert.Nil(t, err)
+
+	caPool := NewCAPool()
+	b, err := caPool.AddCAFromPEM(caPem)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+
+	// ip is outside the network
+	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
+	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is outside the network reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip is within the network but mask is outside reversed order of above
+	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
+	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
+	assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() {
+		NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	})
+
+	// ip and mask are within the network
+	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
+	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+
+	// Exact matches reversed with just 1
+	c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
+	assert.Nil(t, err)
+	_, err = caPool.VerifyCertificate(time.Now(), c)
+	assert.Nil(t, err)
+}

+ 50 - 14
cert/cert.go

@@ -1,15 +1,17 @@
 package cert
 
 import (
+	"fmt"
 	"net/netip"
 	"time"
 )
 
-type Version int
+type Version uint8
 
 const (
-	Version1 Version = 1
-	Version2 Version = 2
+	VersionPre1 Version = 0
+	Version1    Version = 1
+	Version2    Version = 2
 )
 
 type Certificate interface {
@@ -107,23 +109,57 @@ type CachedCertificate struct {
 	signerFingerprint string
 }
 
-// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate.
-func UnmarshalCertificate(b []byte) (Certificate, error) {
-	c, err := unmarshalCertificateV1(b, true)
-	if err != nil {
-		return nil, err
-	}
-	return c, nil
+func (cc *CachedCertificate) String() string {
+	return cc.Certificate.String()
 }
 
-// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
+// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
 // Handshakes save space by placing the peers public key in a different part of the packet, we have to
 // reassemble the actual certificate structure with that in mind.
-func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) {
-	c, err := unmarshalCertificateV1(b, false)
+func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
+	if publicKey == nil {
+		return nil, ErrNoPeerStaticKey
+	}
+
+	if rawCertBytes == nil {
+		return nil, ErrNoPayload
+	}
+
+	c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
+	if err != nil {
+		return nil, fmt.Errorf("error unmarshaling cert: %w", err)
+	}
+
+	cc, err := caPool.VerifyCertificate(time.Now(), c)
+	if err != nil {
+		return nil, fmt.Errorf("certificate validation failed: %w", err)
+	}
+
+	return cc, nil
+}
+
+func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
+	var c Certificate
+	var err error
+
+	switch v {
+	// Implementations must ensure the result is a valid cert!
+	case VersionPre1, Version1:
+		c, err = unmarshalCertificateV1(b, publicKey)
+	case Version2:
+		c, err = unmarshalCertificateV2(b, publicKey, curve)
+	default:
+		//TODO: CERT-V2 make a static var
+		return nil, fmt.Errorf("unknown certificate version %d", v)
+	}
+
 	if err != nil {
 		return nil, err
 	}
-	c.details.PublicKey = publicKey
+
+	if c.Curve() != curve {
+		return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
+	}
+
 	return c, nil
 }

+ 0 - 695
cert/cert_test.go

@@ -1,695 +0,0 @@
-package cert
-
-import (
-	"crypto/ecdh"
-	"crypto/ecdsa"
-	"crypto/elliptic"
-	"crypto/rand"
-	"fmt"
-	"io"
-	"net/netip"
-	"testing"
-	"time"
-
-	"github.com/slackhq/nebula/test"
-	"github.com/stretchr/testify/assert"
-	"golang.org/x/crypto/curve25519"
-	"golang.org/x/crypto/ed25519"
-)
-
-func TestMarshalingNebulaCertificate(t *testing.T) {
-	before := time.Now().Add(time.Second * -60).Round(time.Second)
-	after := time.Now().Add(time.Second * 60).Round(time.Second)
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := certificateV1{
-		details: detailsV1{
-			Name: "testing",
-			Ips: []netip.Prefix{
-				mustParsePrefixUnmapped("10.1.1.1/24"),
-				mustParsePrefixUnmapped("10.1.1.2/16"),
-			},
-			Subnets: []netip.Prefix{
-				mustParsePrefixUnmapped("9.1.1.2/24"),
-				mustParsePrefixUnmapped("9.1.1.3/16"),
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.Marshal()
-	assert.Nil(t, err)
-	//t.Log("Cert size:", len(b))
-
-	nc2, err := unmarshalCertificateV1(b, true)
-	assert.Nil(t, err)
-
-	assert.Equal(t, nc.signature, nc2.Signature())
-	assert.Equal(t, nc.details.Name, nc2.Name())
-	assert.Equal(t, nc.details.NotBefore, nc2.NotBefore())
-	assert.Equal(t, nc.details.NotAfter, nc2.NotAfter())
-	assert.Equal(t, nc.details.PublicKey, nc2.PublicKey())
-	assert.Equal(t, nc.details.IsCA, nc2.IsCA())
-
-	assert.Equal(t, nc.details.Ips, nc2.Networks())
-	assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks())
-
-	assert.Equal(t, nc.details.Groups, nc2.Groups())
-}
-
-//func TestNebulaCertificate_Sign(t *testing.T) {
-//	before := time.Now().Add(time.Second * -60).Round(time.Second)
-//	after := time.Now().Add(time.Second * 60).Round(time.Second)
-//	pubKey := []byte("1234567890abcedfghij1234567890ab")
-//
-//	nc := certificateV1{
-//		details: detailsV1{
-//			Name: "testing",
-//			Ips: []netip.Prefix{
-//				mustParsePrefixUnmapped("10.1.1.1/24"),
-//				mustParsePrefixUnmapped("10.1.1.2/16"),
-//				//TODO: netip cant do it
-//				//{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-//			},
-//			Subnets: []netip.Prefix{
-//				//TODO: netip cant do it
-//				//{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-//				mustParsePrefixUnmapped("9.1.1.2/24"),
-//				mustParsePrefixUnmapped("9.1.1.3/24"),
-//			},
-//			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-//			NotBefore: before,
-//			NotAfter:  after,
-//			PublicKey: pubKey,
-//			IsCA:      false,
-//			Issuer:    "1234567890abcedfghij1234567890ab",
-//		},
-//	}
-//
-//	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-//	assert.Nil(t, err)
-//	assert.False(t, nc.CheckSignature(pub))
-//	assert.Nil(t, nc.Sign(Curve_CURVE25519, priv))
-//	assert.True(t, nc.CheckSignature(pub))
-//
-//	_, err = nc.Marshal()
-//	assert.Nil(t, err)
-//	//t.Log("Cert size:", len(b))
-//}
-
-//func TestNebulaCertificate_SignP256(t *testing.T) {
-//	before := time.Now().Add(time.Second * -60).Round(time.Second)
-//	after := time.Now().Add(time.Second * 60).Round(time.Second)
-//	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
-//
-//	nc := certificateV1{
-//		details: detailsV1{
-//			Name: "testing",
-//			Ips: []netip.Prefix{
-//				mustParsePrefixUnmapped("10.1.1.1/24"),
-//				mustParsePrefixUnmapped("10.1.1.2/16"),
-//				//TODO: netip no can do
-//				//{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-//			},
-//			Subnets: []netip.Prefix{
-//				//TODO: netip bad
-//				//{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-//				mustParsePrefixUnmapped("9.1.1.2/24"),
-//				mustParsePrefixUnmapped("9.1.1.3/16"),
-//			},
-//			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-//			NotBefore: before,
-//			NotAfter:  after,
-//			PublicKey: pubKey,
-//			IsCA:      false,
-//			Curve:     Curve_P256,
-//			Issuer:    "1234567890abcedfghij1234567890ab",
-//		},
-//	}
-//
-//	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-//	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-//	rawPriv := priv.D.FillBytes(make([]byte, 32))
-//
-//	assert.Nil(t, err)
-//	assert.False(t, nc.CheckSignature(pub))
-//	assert.Nil(t, nc.Sign(Curve_P256, rawPriv))
-//	assert.True(t, nc.CheckSignature(pub))
-//
-//	_, err = nc.Marshal()
-//	assert.Nil(t, err)
-//	//t.Log("Cert size:", len(b))
-//}
-
-func TestNebulaCertificate_Expired(t *testing.T) {
-	nc := certificateV1{
-		details: detailsV1{
-			NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
-			NotAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
-		},
-	}
-
-	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
-	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
-	assert.False(t, nc.Expired(time.Now()))
-}
-
-func TestNebulaCertificate_MarshalJSON(t *testing.T) {
-	time.Local = time.UTC
-	pubKey := []byte("1234567890abcedfghij1234567890ab")
-
-	nc := certificateV1{
-		details: detailsV1{
-			Name: "testing",
-			Ips: []netip.Prefix{
-				mustParsePrefixUnmapped("10.1.1.1/24"),
-				mustParsePrefixUnmapped("10.1.1.2/16"),
-			},
-			Subnets: []netip.Prefix{
-				mustParsePrefixUnmapped("9.1.1.2/24"),
-				mustParsePrefixUnmapped("9.1.1.3/16"),
-			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
-			NotAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
-		},
-		signature: []byte("1234567890abcedfghij1234567890ab"),
-	}
-
-	b, err := nc.MarshalJSON()
-	assert.Nil(t, err)
-	assert.Equal(
-		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
-		string(b),
-	)
-}
-
-func TestNebulaCertificate_Verify(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	assert.NoError(t, caPool.AddCA(ca))
-
-	f, err := c.Fingerprint()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
-	assert.EqualError(t, err, "certificate is valid before the signing certificate")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
-	assert.Empty(t, b)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	assert.NoError(t, caPool.AddCA(ca))
-
-	f, err := c.Fingerprint()
-	assert.Nil(t, err)
-	caPool.BlocklistFingerprint(f)
-
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.EqualError(t, err, "certificate is in the block list")
-
-	caPool.ResetCertBlocklist()
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
-	assert.EqualError(t, err, "root certificate is expired")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
-	assert.EqualError(t, err, "certificate is valid before the signing certificate")
-
-	// Test group assertion
-	ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalPEM()
-	assert.Nil(t, err)
-
-	caPool = NewCAPool()
-	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
-	assert.Empty(t, b)
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"})
-	assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad")
-
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_IPs(t *testing.T) {
-	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
-	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
-	assert.Empty(t, b)
-
-	// ip is outside the network
-	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
-	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
-	assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
-	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
-	assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
-	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
-	assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
-	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
-	assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
-	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_Verify_Subnets(t *testing.T) {
-	caIp1 := mustParsePrefixUnmapped("10.0.0.0/16")
-	caIp2 := mustParsePrefixUnmapped("192.168.0.0/24")
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-
-	caPem, err := ca.MarshalPEM()
-	assert.Nil(t, err)
-
-	caPool := NewCAPool()
-	b, err := caPool.AddCAFromPEM(caPem)
-	assert.NoError(t, err)
-	assert.Empty(t, b)
-
-	// ip is outside the network
-	cIp1 := mustParsePrefixUnmapped("10.1.0.0/24")
-	cIp2 := mustParsePrefixUnmapped("192.168.0.1/16")
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is outside the network reversed order of above
-	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
-	cIp2 = mustParsePrefixUnmapped("10.1.0.0/24")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24")
-
-	// ip is within the network but mask is outside
-	cIp1 = mustParsePrefixUnmapped("10.0.1.0/15")
-	cIp2 = mustParsePrefixUnmapped("192.168.0.1/24")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip is within the network but mask is outside reversed order of above
-	cIp1 = mustParsePrefixUnmapped("192.168.0.1/24")
-	cIp2 = mustParsePrefixUnmapped("10.0.1.0/15")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15")
-
-	// ip and mask are within the network
-	cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
-	cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	// Exact matches
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	// Exact matches reversed
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-
-	// Exact matches reversed with just 1
-	c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
-	assert.Nil(t, err)
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	assert.Nil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil)
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil)
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := x25519Keypair()
-	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
-	assert.NotNil(t, err)
-}
-
-func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) {
-	ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil)
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey)
-	assert.Nil(t, err)
-
-	_, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil)
-	assert.Nil(t, err)
-	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
-	assert.NotNil(t, err)
-
-	c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil)
-	err = c.VerifyPrivateKey(Curve_P256, priv)
-	assert.Nil(t, err)
-
-	_, priv2 := p256Keypair()
-	err = c.VerifyPrivateKey(Curve_P256, priv2)
-	assert.NotNil(t, err)
-}
-
-func appendByteSlices(b ...[]byte) []byte {
-	retSlice := []byte{}
-	for _, v := range b {
-		retSlice = append(retSlice, v...)
-	}
-	return retSlice
-}
-
-// Ensure that upgrading the protobuf library does not change how certificates
-// are marshalled, since this would break signature verification
-//TODO: since netip cant represent 255.0.255.0 netmask we can't verify the old certs are ok
-//func TestMarshalingNebulaCertificateConsistency(t *testing.T) {
-//	before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
-//	after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC)
-//	pubKey := []byte("1234567890abcedfghij1234567890ab")
-//
-//	nc := certificateV1{
-//		details: detailsV1{
-//			Name: "testing",
-//			Ips: []netip.Prefix{
-//				mustParsePrefixUnmapped("10.1.1.1/24"),
-//				mustParsePrefixUnmapped("10.1.1.2/16"),
-//				//TODO: netip bad
-//				//{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-//			},
-//			Subnets: []netip.Prefix{
-//				//TODO: netip bad
-//				//{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))},
-//				mustParsePrefixUnmapped("9.1.1.2/24"),
-//				mustParsePrefixUnmapped("9.1.1.3/16"),
-//			},
-//			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-//			NotBefore: before,
-//			NotAfter:  after,
-//			PublicKey: pubKey,
-//			IsCA:      false,
-//			Issuer:    "1234567890abcedfghij1234567890ab",
-//		},
-//		signature: []byte("1234567890abcedfghij1234567890ab"),
-//	}
-//
-//	b, err := nc.Marshal()
-//	assert.Nil(t, err)
-//	//t.Log("Cert size:", len(b))
-//	assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
-//
-//	b, err = proto.Marshal(nc.getRawDetails())
-//	assert.Nil(t, err)
-//	//t.Log("Raw cert size:", len(b))
-//	assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
-//}
-
-func TestNebulaCertificate_Copy(t *testing.T) {
-	ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
-	assert.Nil(t, err)
-
-	c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
-	assert.Nil(t, err)
-	cc := c.Copy()
-
-	test.AssertDeepCopyEqual(t, c, cc)
-}
-
-func TestUnmarshalNebulaCertificate(t *testing.T) {
-	// Test that we don't panic with an invalid certificate (#332)
-	data := []byte("\x98\x00\x00")
-	_, err := unmarshalCertificateV1(data, true)
-	assert.EqualError(t, err, "encoded Details was nil")
-}
-
-func newTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	tbs := &TBSCertificate{
-		Version:   Version1,
-		Name:      "test ca",
-		IsCA:      true,
-		NotBefore: time.Unix(before.Unix(), 0),
-		NotAfter:  time.Unix(after.Unix(), 0),
-		PublicKey: pub,
-	}
-
-	if len(ips) > 0 {
-		tbs.Networks = ips
-	}
-
-	if len(subnets) > 0 {
-		tbs.UnsafeNetworks = subnets
-	}
-
-	if len(groups) > 0 {
-		tbs.Groups = groups
-	}
-
-	nc, err := tbs.Sign(nil, Curve_CURVE25519, priv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, priv, nil
-}
-
-func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) {
-	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
-	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
-	rawPriv := priv.D.FillBytes(make([]byte, 32))
-
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	tbs := &TBSCertificate{
-		Version:   Version1,
-		Name:      "test ca",
-		IsCA:      true,
-		NotBefore: time.Unix(before.Unix(), 0),
-		NotAfter:  time.Unix(after.Unix(), 0),
-		PublicKey: pub,
-		Curve:     Curve_P256,
-	}
-
-	if len(ips) > 0 {
-		tbs.Networks = ips
-	}
-
-	if len(subnets) > 0 {
-		tbs.UnsafeNetworks = subnets
-	}
-
-	if len(groups) > 0 {
-		tbs.Groups = groups
-	}
-
-	nc, err := tbs.Sign(nil, Curve_P256, rawPriv)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-	return nc, pub, rawPriv, nil
-}
-
-func newTestCert(ca Certificate, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) {
-	if before.IsZero() {
-		before = time.Now().Add(time.Second * -60).Round(time.Second)
-	}
-	if after.IsZero() {
-		after = time.Now().Add(time.Second * 60).Round(time.Second)
-	}
-
-	if len(groups) == 0 {
-		groups = []string{"test-group1", "test-group2", "test-group3"}
-	}
-
-	if len(ips) == 0 {
-		ips = []netip.Prefix{
-			mustParsePrefixUnmapped("10.1.1.1/24"),
-			mustParsePrefixUnmapped("10.1.1.2/16"),
-		}
-	}
-
-	if len(subnets) == 0 {
-		subnets = []netip.Prefix{
-			mustParsePrefixUnmapped("9.1.1.2/24"),
-			mustParsePrefixUnmapped("9.1.1.3/16"),
-		}
-	}
-
-	var pub, rawPriv []byte
-
-	switch ca.Curve() {
-	case Curve_CURVE25519:
-		pub, rawPriv = x25519Keypair()
-	case Curve_P256:
-		pub, rawPriv = p256Keypair()
-	default:
-		return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Curve())
-	}
-
-	tbs := &TBSCertificate{
-		Version:        Version1,
-		Name:           "testing",
-		Networks:       ips,
-		UnsafeNetworks: subnets,
-		Groups:         groups,
-		IsCA:           false,
-		NotBefore:      time.Unix(before.Unix(), 0),
-		NotAfter:       time.Unix(after.Unix(), 0),
-		PublicKey:      pub,
-		Curve:          ca.Curve(),
-	}
-
-	nc, err := tbs.Sign(ca, ca.Curve(), key)
-	if err != nil {
-		return nil, nil, nil, err
-	}
-
-	return nc, pub, rawPriv, nil
-}
-
-func x25519Keypair() ([]byte, []byte) {
-	privkey := make([]byte, 32)
-	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
-		panic(err)
-	}
-
-	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
-	if err != nil {
-		panic(err)
-	}
-
-	return pubkey, privkey
-}
-
-func p256Keypair() ([]byte, []byte) {
-	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
-	if err != nil {
-		panic(err)
-	}
-	pubkey := privkey.PublicKey()
-	return pubkey.Bytes(), privkey.Bytes()
-}
-
-func mustParsePrefixUnmapped(s string) netip.Prefix {
-	prefix := netip.MustParsePrefix(s)
-	return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
-}

+ 225 - 232
cert/cert_v1.go

@@ -6,19 +6,16 @@ import (
 	"crypto/ecdsa"
 	"crypto/ed25519"
 	"crypto/elliptic"
-	"crypto/rand"
 	"crypto/sha256"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/json"
 	"encoding/pem"
 	"fmt"
-	"math/big"
 	"net"
 	"net/netip"
 	"time"
 
-	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 	"google.golang.org/protobuf/proto"
 )
@@ -31,71 +28,71 @@ type certificateV1 struct {
 }
 
 type detailsV1 struct {
-	Name      string
-	Ips       []netip.Prefix
-	Subnets   []netip.Prefix
-	Groups    []string
-	NotBefore time.Time
-	NotAfter  time.Time
-	PublicKey []byte
-	IsCA      bool
-	Issuer    string
-
-	Curve Curve
+	name           string
+	networks       []netip.Prefix
+	unsafeNetworks []netip.Prefix
+	groups         []string
+	notBefore      time.Time
+	notAfter       time.Time
+	publicKey      []byte
+	isCA           bool
+	issuer         string
+
+	curve Curve
 }
 
 type m map[string]interface{}
 
-func (nc *certificateV1) Version() Version {
+func (c *certificateV1) Version() Version {
 	return Version1
 }
 
-func (nc *certificateV1) Curve() Curve {
-	return nc.details.Curve
+func (c *certificateV1) Curve() Curve {
+	return c.details.curve
 }
 
-func (nc *certificateV1) Groups() []string {
-	return nc.details.Groups
+func (c *certificateV1) Groups() []string {
+	return c.details.groups
 }
 
-func (nc *certificateV1) IsCA() bool {
-	return nc.details.IsCA
+func (c *certificateV1) IsCA() bool {
+	return c.details.isCA
 }
 
-func (nc *certificateV1) Issuer() string {
-	return nc.details.Issuer
+func (c *certificateV1) Issuer() string {
+	return c.details.issuer
 }
 
-func (nc *certificateV1) Name() string {
-	return nc.details.Name
+func (c *certificateV1) Name() string {
+	return c.details.name
 }
 
-func (nc *certificateV1) Networks() []netip.Prefix {
-	return nc.details.Ips
+func (c *certificateV1) Networks() []netip.Prefix {
+	return c.details.networks
 }
 
-func (nc *certificateV1) NotAfter() time.Time {
-	return nc.details.NotAfter
+func (c *certificateV1) NotAfter() time.Time {
+	return c.details.notAfter
 }
 
-func (nc *certificateV1) NotBefore() time.Time {
-	return nc.details.NotBefore
+func (c *certificateV1) NotBefore() time.Time {
+	return c.details.notBefore
 }
 
-func (nc *certificateV1) PublicKey() []byte {
-	return nc.details.PublicKey
+func (c *certificateV1) PublicKey() []byte {
+	return c.details.publicKey
 }
 
-func (nc *certificateV1) Signature() []byte {
-	return nc.signature
+func (c *certificateV1) Signature() []byte {
+	return c.signature
 }
 
-func (nc *certificateV1) UnsafeNetworks() []netip.Prefix {
-	return nc.details.Subnets
+func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
 }
 
-func (nc *certificateV1) Fingerprint() (string, error) {
-	b, err := nc.Marshal()
+func (c *certificateV1) Fingerprint() (string, error) {
+	b, err := c.Marshal()
 	if err != nil {
 		return "", err
 	}
@@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) {
 	return hex.EncodeToString(sum[:]), nil
 }
 
-func (nc *certificateV1) CheckSignature(key []byte) bool {
-	b, err := proto.Marshal(nc.getRawDetails())
+func (c *certificateV1) CheckSignature(key []byte) bool {
+	b, err := proto.Marshal(c.getRawDetails())
 	if err != nil {
 		return false
 	}
-	switch nc.details.Curve {
+	switch c.details.curve {
 	case Curve_CURVE25519:
-		return ed25519.Verify(key, b, nc.signature)
+		return ed25519.Verify(key, b, c.signature)
 	case Curve_P256:
 		x, y := elliptic.Unmarshal(elliptic.P256(), key)
 		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
 		hashed := sha256.Sum256(b)
-		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
 	default:
 		return false
 	}
 }
 
-func (nc *certificateV1) Expired(t time.Time) bool {
-	return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t)
+func (c *certificateV1) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
 }
 
-func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
-	if curve != nc.details.Curve {
+func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.details.curve {
 		return fmt.Errorf("curve in cert and private key supplied don't match")
 	}
-	if nc.details.IsCA {
+	if c.details.isCA {
 		switch curve {
 		case Curve_CURVE25519:
 			// the call to PublicKey below will panic slice bounds out of range otherwise
@@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
 			}
 
-			if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
+			if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
 				return fmt.Errorf("public key in cert and private key supplied don't match")
 			}
 		case Curve_P256:
@@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 				return fmt.Errorf("cannot parse private key as P256: %w", err)
 			}
 			pub := privkey.PublicKey().Bytes()
-			if !bytes.Equal(pub, nc.details.PublicKey) {
+			if !bytes.Equal(pub, c.details.publicKey) {
 				return fmt.Errorf("public key in cert and private key supplied don't match")
 			}
 		default:
@@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 	default:
 		return fmt.Errorf("invalid curve: %s", curve)
 	}
-	if !bytes.Equal(pub, nc.details.PublicKey) {
+	if !bytes.Equal(pub, c.details.publicKey) {
 		return fmt.Errorf("public key in cert and private key supplied don't match")
 	}
 
@@ -181,173 +178,219 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 }
 
 // getRawDetails marshals the raw details into protobuf ready struct
-func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
+func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
 	rd := &RawNebulaCertificateDetails{
-		Name:      nc.details.Name,
-		Groups:    nc.details.Groups,
-		NotBefore: nc.details.NotBefore.Unix(),
-		NotAfter:  nc.details.NotAfter.Unix(),
-		PublicKey: make([]byte, len(nc.details.PublicKey)),
-		IsCA:      nc.details.IsCA,
-		Curve:     nc.details.Curve,
+		Name:      c.details.name,
+		Groups:    c.details.groups,
+		NotBefore: c.details.notBefore.Unix(),
+		NotAfter:  c.details.notAfter.Unix(),
+		PublicKey: make([]byte, len(c.details.publicKey)),
+		IsCA:      c.details.isCA,
+		Curve:     c.details.curve,
 	}
 
-	for _, ipNet := range nc.details.Ips {
+	for _, ipNet := range c.details.networks {
 		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
 		rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
 	}
 
-	for _, ipNet := range nc.details.Subnets {
+	for _, ipNet := range c.details.unsafeNetworks {
 		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
 		rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
 	}
 
-	copy(rd.PublicKey, nc.details.PublicKey[:])
+	copy(rd.PublicKey, c.details.publicKey[:])
 
 	// I know, this is terrible
-	rd.Issuer, _ = hex.DecodeString(nc.details.Issuer)
+	rd.Issuer, _ = hex.DecodeString(c.details.issuer)
 
 	return rd
 }
 
-func (nc *certificateV1) String() string {
-	if nc == nil {
-		return "Certificate {}\n"
-	}
-
-	s := "NebulaCertificate {\n"
-	s += "\tDetails {\n"
-	s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name)
-
-	if len(nc.details.Ips) > 0 {
-		s += "\t\tIps: [\n"
-		for _, ip := range nc.details.Ips {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tIps: []\n"
-	}
-
-	if len(nc.details.Subnets) > 0 {
-		s += "\t\tSubnets: [\n"
-		for _, ip := range nc.details.Subnets {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tSubnets: []\n"
-	}
-
-	if len(nc.details.Groups) > 0 {
-		s += "\t\tGroups: [\n"
-		for _, g := range nc.details.Groups {
-			s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tGroups: []\n"
-	}
-
-	s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore)
-	s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter)
-	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA)
-	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer)
-	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey)
-	s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve)
-	s += "\t}\n"
-	fp, err := nc.Fingerprint()
-	if err == nil {
-		s += fmt.Sprintf("\tFingerprint: %s\n", fp)
+func (c *certificateV1) String() string {
+	b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
 	}
-	s += fmt.Sprintf("\tSignature: %x\n", nc.Signature())
-	s += "}"
-
-	return s
+	return string(b)
 }
 
-func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) {
-	pubKey := nc.details.PublicKey
-	nc.details.PublicKey = nil
-	rawCertNoKey, err := nc.Marshal()
+func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
+	pubKey := c.details.publicKey
+	c.details.publicKey = nil
+	rawCertNoKey, err := c.Marshal()
 	if err != nil {
 		return nil, err
 	}
-	nc.details.PublicKey = pubKey
+	c.details.publicKey = pubKey
 	return rawCertNoKey, nil
 }
 
-func (nc *certificateV1) Marshal() ([]byte, error) {
+func (c *certificateV1) Marshal() ([]byte, error) {
 	rc := RawNebulaCertificate{
-		Details:   nc.getRawDetails(),
-		Signature: nc.signature,
+		Details:   c.getRawDetails(),
+		Signature: c.signature,
 	}
 
 	return proto.Marshal(&rc)
 }
 
-func (nc *certificateV1) MarshalPEM() ([]byte, error) {
-	b, err := nc.Marshal()
+func (c *certificateV1) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
 	if err != nil {
 		return nil, err
 	}
 	return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
 }
 
-func (nc *certificateV1) MarshalJSON() ([]byte, error) {
-	fp, _ := nc.Fingerprint()
-	jc := m{
+func (c *certificateV1) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.marshalJSON())
+}
+
+func (c *certificateV1) marshalJSON() m {
+	fp, _ := c.Fingerprint()
+	return m{
+		"version": Version1,
 		"details": m{
-			"name":      nc.details.Name,
-			"ips":       nc.details.Ips,
-			"subnets":   nc.details.Subnets,
-			"groups":    nc.details.Groups,
-			"notBefore": nc.details.NotBefore,
-			"notAfter":  nc.details.NotAfter,
-			"publicKey": fmt.Sprintf("%x", nc.details.PublicKey),
-			"isCa":      nc.details.IsCA,
-			"issuer":    nc.details.Issuer,
-			"curve":     nc.details.Curve.String(),
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"publicKey":      fmt.Sprintf("%x", c.details.publicKey),
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+			"curve":          c.details.curve.String(),
 		},
 		"fingerprint": fp,
-		"signature":   fmt.Sprintf("%x", nc.Signature()),
+		"signature":   fmt.Sprintf("%x", c.Signature()),
 	}
-	return json.Marshal(jc)
 }
 
-func (nc *certificateV1) Copy() Certificate {
-	c := &certificateV1{
+func (c *certificateV1) Copy() Certificate {
+	nc := &certificateV1{
 		details: detailsV1{
-			Name:      nc.details.Name,
-			Groups:    make([]string, len(nc.details.Groups)),
-			Ips:       make([]netip.Prefix, len(nc.details.Ips)),
-			Subnets:   make([]netip.Prefix, len(nc.details.Subnets)),
-			NotBefore: nc.details.NotBefore,
-			NotAfter:  nc.details.NotAfter,
-			PublicKey: make([]byte, len(nc.details.PublicKey)),
-			IsCA:      nc.details.IsCA,
-			Issuer:    nc.details.Issuer,
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			publicKey: make([]byte, len(c.details.publicKey)),
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+			curve:     c.details.curve,
 		},
-		signature: make([]byte, len(nc.signature)),
+		signature: make([]byte, len(c.signature)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.signature, c.signature)
+	copy(nc.details.publicKey, c.details.publicKey)
+
+	return nc
+}
+
+func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV1{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		publicKey:      t.PublicKey,
+		isCA:           t.IsCA,
+		curve:          t.Curve,
+		issuer:         t.issuer,
+	}
+
+	return c.validate()
+}
+
+func (c *certificateV1) validate() error {
+	// Empty names are allowed
+
+	if len(c.details.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
+	// Continue to allow this behavior
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
+	}
+
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
 	}
 
-	copy(c.signature, nc.signature)
-	copy(c.details.Groups, nc.details.Groups)
-	copy(c.details.PublicKey, nc.details.PublicKey)
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
+		}
 
-	for i, p := range nc.details.Ips {
-		c.details.Ips[i] = p
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
 	}
 
-	for i, p := range nc.details.Subnets {
-		c.details.Subnets[i] = p
+	// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
+	// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
+	// unsafe networks would result in a different signature.
+
+	return nil
+}
+
+func (c *certificateV1) marshalForSigning() ([]byte, error) {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return nil, err
 	}
+	return b, nil
+}
 
-	return c
+func (c *certificateV1) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
 }
 
 // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
-func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) {
+// if the publicKey is provided here then it is not required to be present in `b`
+func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
 	if len(b) == 0 {
 		return nil, fmt.Errorf("nil byte array")
 	}
@@ -371,27 +414,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
 
 	nc := certificateV1{
 		details: detailsV1{
-			Name:      rc.Details.Name,
-			Groups:    make([]string, len(rc.Details.Groups)),
-			Ips:       make([]netip.Prefix, len(rc.Details.Ips)/2),
-			Subnets:   make([]netip.Prefix, len(rc.Details.Subnets)/2),
-			NotBefore: time.Unix(rc.Details.NotBefore, 0),
-			NotAfter:  time.Unix(rc.Details.NotAfter, 0),
-			PublicKey: make([]byte, len(rc.Details.PublicKey)),
-			IsCA:      rc.Details.IsCA,
-			Curve:     rc.Details.Curve,
+			name:           rc.Details.Name,
+			groups:         make([]string, len(rc.Details.Groups)),
+			networks:       make([]netip.Prefix, len(rc.Details.Ips)/2),
+			unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
+			notBefore:      time.Unix(rc.Details.NotBefore, 0),
+			notAfter:       time.Unix(rc.Details.NotAfter, 0),
+			publicKey:      make([]byte, len(rc.Details.PublicKey)),
+			isCA:           rc.Details.IsCA,
+			curve:          rc.Details.Curve,
 		},
 		signature: make([]byte, len(rc.Signature)),
 	}
 
 	copy(nc.signature, rc.Signature)
-	copy(nc.details.Groups, rc.Details.Groups)
-	nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer)
+	copy(nc.details.groups, rc.Details.Groups)
+	nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
 
-	if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey {
-		return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
+	if len(publicKey) > 0 {
+		nc.details.publicKey = publicKey
 	}
-	copy(nc.details.PublicKey, rc.Details.PublicKey)
+
+	copy(nc.details.publicKey, rc.Details.PublicKey)
 
 	var ip netip.Addr
 	for i, rawIp := range rc.Details.Ips {
@@ -399,7 +443,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
 			ip = int2addr(rawIp)
 		} else {
 			ones, _ := net.IPMask(int2ip(rawIp)).Size()
-			nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones)
+			nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
 		}
 	}
 
@@ -408,67 +452,16 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
 			ip = int2addr(rawIp)
 		} else {
 			ones, _ := net.IPMask(int2ip(rawIp)).Size()
-			nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones)
+			nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
 		}
 	}
 
-	return &nc, nil
-}
-
-func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) {
-	c := &certificateV1{
-		details: detailsV1{
-			Name:      t.Name,
-			Ips:       t.Networks,
-			Subnets:   t.UnsafeNetworks,
-			Groups:    t.Groups,
-			NotBefore: t.NotBefore,
-			NotAfter:  t.NotAfter,
-			PublicKey: t.PublicKey,
-			IsCA:      t.IsCA,
-			Curve:     t.Curve,
-			Issuer:    t.issuer,
-		},
-	}
-	b, err := proto.Marshal(c.getRawDetails())
+	err = nc.validate()
 	if err != nil {
 		return nil, err
 	}
 
-	var sig []byte
-
-	switch curve {
-	case Curve_CURVE25519:
-		signer := ed25519.PrivateKey(key)
-		sig = ed25519.Sign(signer, b)
-	case Curve_P256:
-		if client != nil {
-			sig, err = client.SignASN1(b)
-		} else {
-			signer := &ecdsa.PrivateKey{
-				PublicKey: ecdsa.PublicKey{
-					Curve: elliptic.P256(),
-				},
-				// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-				D: new(big.Int).SetBytes(key),
-			}
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-			signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
-
-			// We need to hash first for ECDSA
-			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
-			hashed := sha256.Sum256(b)
-			sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
-			if err != nil {
-				return nil, err
-			}
-		}
-	default:
-		return nil, fmt.Errorf("invalid curve: %s", c.details.Curve)
-	}
-
-	c.signature = sig
-	return c, nil
+	return &nc, nil
 }
 
 func ip2int(ip []byte) uint32 {

+ 218 - 0
cert/cert_v1_test.go

@@ -0,0 +1,218 @@
+package cert
+
+import (
+	"fmt"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/protobuf/proto"
+)
+
+func TestCertificateV1_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	assert.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+
+	assert.Equal(t, nc.Version(), Version1)
+	assert.Equal(t, nc.Curve(), Curve_CURVE25519)
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV1_Expired(t *testing.T) {
+	nc := certificateV1{
+		details: detailsV1{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV1_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
+		string(b),
+	)
+}
+
+func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.NotNil(t, err)
+}
+
+// Ensure that upgrading the protobuf library does not change how certificates
+// are marshalled, since this would break signature verification
+func TestMarshalingCertificateV1Consistency(t *testing.T) {
+	before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC)
+	after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV1{
+		details: detailsV1{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
+		},
+		signature: []byte("1234567890abcedfghij1234567890ab"),
+	}
+
+	b, err := nc.Marshal()
+	require.Nil(t, err)
+	assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
+
+	b, err = proto.Marshal(nc.getRawDetails())
+	assert.Nil(t, err)
+	assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
+}
+
+func TestCertificateV1_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV1(t *testing.T) {
+	// Test that we don't panic with an invalid certificate (#332)
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV1(data, nil)
+	assert.EqualError(t, err, "encoded Details was nil")
+}
+
+func appendByteSlices(b ...[]byte) []byte {
+	retSlice := []byte{}
+	for _, v := range b {
+		retSlice = append(retSlice, v...)
+	}
+	return retSlice
+}
+
+func mustParsePrefixUnmapped(s string) netip.Prefix {
+	prefix := netip.MustParsePrefix(s)
+	return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits())
+}

+ 37 - 0
cert/cert_v2.asn1

@@ -0,0 +1,37 @@
+Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
+
+Name ::= UTF8String (SIZE (1..253))
+Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
+Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
+Curve ::= ENUMERATED {
+    curve25519 (0),
+    p256 (1)
+}
+
+-- The maximum size of a certificate must not exceed 65536 bytes
+Certificate ::= SEQUENCE {
+    details OCTET STRING,
+    curve Curve DEFAULT curve25519,
+    publicKey OCTET STRING,
+    -- signature(details + curve + publicKey) using the appropriate method for curve
+    signature OCTET STRING
+}
+
+Details ::= SEQUENCE {
+    name Name,
+
+    -- At least 1 ipv4 or ipv6 address must be present if isCA is false
+    networks SEQUENCE OF Network OPTIONAL,
+    unsafeNetworks SEQUENCE OF Network OPTIONAL,
+    groups SEQUENCE OF Name OPTIONAL,
+    isCA BOOLEAN DEFAULT false,
+    notBefore Time,
+    notAfter Time,
+
+    -- issuer is only required if isCA is false, if isCA is true then it must not be present
+    issuer OCTET STRING OPTIONAL,
+    ...
+    -- New fields can be added below here
+}
+
+END

+ 730 - 0
cert/cert_v2.go

@@ -0,0 +1,730 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net/netip"
+	"slices"
+	"time"
+
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+	"golang.org/x/crypto/curve25519"
+)
+
+const (
+	classConstructed     = 0x20
+	classContextSpecific = 0x80
+
+	TagCertDetails   = 0 | classConstructed | classContextSpecific
+	TagCertCurve     = 1 | classContextSpecific
+	TagCertPublicKey = 2 | classContextSpecific
+	TagCertSignature = 3 | classContextSpecific
+
+	TagDetailsName           = 0 | classContextSpecific
+	TagDetailsNetworks       = 1 | classConstructed | classContextSpecific
+	TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
+	TagDetailsGroups         = 3 | classConstructed | classContextSpecific
+	TagDetailsIsCA           = 4 | classContextSpecific
+	TagDetailsNotBefore      = 5 | classContextSpecific
+	TagDetailsNotAfter       = 6 | classContextSpecific
+	TagDetailsIssuer         = 7 | classContextSpecific
+)
+
+const (
+	// MaxCertificateSize is the maximum length a valid certificate can be
+	MaxCertificateSize = 65536
+
+	// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
+	MaxNameLength = 253
+
+	// MaxNetworkLength is the maximum length a network value can be.
+	// 16 bytes for an ipv6 address + 1 byte for the prefix length
+	MaxNetworkLength = 17
+)
+
+type certificateV2 struct {
+	details detailsV2
+
+	// RawDetails contains the entire asn.1 DER encoded Details struct
+	// This is to benefit forwards compatibility in signature checking.
+	// signature(RawDetails + Curve + PublicKey) == Signature
+	rawDetails []byte
+	curve      Curve
+	publicKey  []byte
+	signature  []byte
+}
+
+type detailsV2 struct {
+	name           string
+	networks       []netip.Prefix // MUST BE SORTED
+	unsafeNetworks []netip.Prefix // MUST BE SORTED
+	groups         []string
+	isCA           bool
+	notBefore      time.Time
+	notAfter       time.Time
+	issuer         string
+}
+
+func (c *certificateV2) Version() Version {
+	return Version2
+}
+
+func (c *certificateV2) Curve() Curve {
+	return c.curve
+}
+
+func (c *certificateV2) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV2) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV2) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV2) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV2) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV2) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV2) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV2) PublicKey() []byte {
+	return c.publicKey
+}
+
+func (c *certificateV2) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV2) Fingerprint() (string, error) {
+	if len(c.rawDetails) == 0 {
+		return "", ErrMissingDetails
+	}
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV2) CheckSignature(key []byte) bool {
+	if len(c.rawDetails) == 0 {
+		return false
+	}
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+
+	switch c.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV2) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.curve {
+		return ErrPublicPrivateCurveMismatch
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return ErrInvalidPrivateKey
+			}
+
+			if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return ErrInvalidPrivateKey
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.publicKey) {
+				return ErrPublicPrivateKeyMismatch
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return ErrInvalidPrivateKey
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.publicKey) {
+		return ErrPublicPrivateKeyMismatch
+	}
+
+	return nil
+}
+
+func (c *certificateV2) String() string {
+	mb, err := c.marshalJSON()
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+
+	b, err := json.MarshalIndent(mb, "", "\t")
+	if err != nil {
+		return fmt.Sprintf("<error marshalling certificate: %v>", err)
+	}
+	return string(b)
+}
+
+func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Skipping the curve and public key since those come across in a different part of the handshake
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) Marshal() ([]byte, error) {
+	if c.rawDetails == nil {
+		return nil, ErrEmptyRawDetails
+	}
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Add the curve only if its not the default value
+		if c.curve != Curve_CURVE25519 {
+			b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
+				b.AddBytes([]byte{byte(c.curve)})
+			})
+		}
+
+		// Add the public key if it is not empty
+		if c.publicKey != nil {
+			b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
+				b.AddBytes(c.publicKey)
+			})
+		}
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
+}
+
+func (c *certificateV2) MarshalJSON() ([]byte, error) {
+	b, err := c.marshalJSON()
+	if err != nil {
+		return nil, err
+	}
+	return json.Marshal(b)
+}
+
+func (c *certificateV2) marshalJSON() (m, error) {
+	fp, err := c.Fingerprint()
+	if err != nil {
+		return nil, err
+	}
+
+	return m{
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+		},
+		"version":     Version2,
+		"publicKey":   fmt.Sprintf("%x", c.publicKey),
+		"curve":       c.curve.String(),
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}, nil
+}
+
+func (c *certificateV2) Copy() Certificate {
+	nc := &certificateV2{
+		details: detailsV2{
+			name:      c.details.name,
+			notBefore: c.details.notBefore,
+			notAfter:  c.details.notAfter,
+			isCA:      c.details.isCA,
+			issuer:    c.details.issuer,
+		},
+		curve:      c.curve,
+		publicKey:  make([]byte, len(c.publicKey)),
+		signature:  make([]byte, len(c.signature)),
+		rawDetails: make([]byte, len(c.rawDetails)),
+	}
+
+	if c.details.groups != nil {
+		nc.details.groups = make([]string, len(c.details.groups))
+		copy(nc.details.groups, c.details.groups)
+	}
+
+	if c.details.networks != nil {
+		nc.details.networks = make([]netip.Prefix, len(c.details.networks))
+		copy(nc.details.networks, c.details.networks)
+	}
+
+	if c.details.unsafeNetworks != nil {
+		nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks))
+		copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+	}
+
+	copy(nc.rawDetails, c.rawDetails)
+	copy(nc.signature, c.signature)
+	copy(nc.publicKey, c.publicKey)
+
+	return nc
+}
+
+func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV2{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		isCA:           t.IsCA,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		issuer:         t.issuer,
+	}
+	c.curve = t.Curve
+	c.publicKey = t.PublicKey
+	return c.validate()
+}
+
+func (c *certificateV2) validate() error {
+	// Empty names are allowed
+
+	if len(c.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
+	}
+
+	hasV4Networks := false
+	hasV6Networks := false
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+
+		if network.Addr().Is4In6() {
+			return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
+		}
+
+		hasV4Networks = hasV4Networks || network.Addr().Is4()
+		hasV6Networks = hasV6Networks || network.Addr().Is6()
+	}
+
+	slices.SortFunc(c.details.networks, comparePrefix)
+	err := findDuplicatePrefix(c.details.networks)
+	if err != nil {
+		return err
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+
+		if !c.details.isCA {
+			if network.Addr().Is6() {
+				if !hasV6Networks {
+					return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
+				}
+			} else if network.Addr().Is4() {
+				if !hasV4Networks {
+					return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
+				}
+			}
+		}
+	}
+
+	slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
+	err = findDuplicatePrefix(c.details.unsafeNetworks)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *certificateV2) marshalForSigning() ([]byte, error) {
+	d, err := c.details.Marshal()
+	if err != nil {
+		return nil, fmt.Errorf("marshalling certificate details failed: %w", err)
+	}
+	c.rawDetails = d
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	return b, nil
+}
+
+func (c *certificateV2) setSignature(b []byte) error {
+	if len(b) == 0 {
+		return ErrEmptySignature
+	}
+	c.signature = b
+	return nil
+}
+
+func (d *detailsV2) Marshal() ([]byte, error) {
+	var b cryptobyte.Builder
+	var err error
+
+	// Details are a structure
+	b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
+
+		// Add the name
+		b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
+			b.AddBytes([]byte(d.name))
+		})
+
+		// Add the networks if any exist
+		if len(d.networks) > 0 {
+			b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.networks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add the unsafe networks if any exist
+		if len(d.unsafeNetworks) > 0 {
+			b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.unsafeNetworks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add groups if any exist
+		if len(d.groups) > 0 {
+			b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
+				for _, group := range d.groups {
+					b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
+						b.AddBytes([]byte(group))
+					})
+				}
+			})
+		}
+
+		// Add IsCA only if true
+		if d.isCA {
+			b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
+				b.AddUint8(0xff)
+			})
+		}
+
+		// Add not before
+		b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
+
+		// Add not after
+		b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
+
+		// Add the issuer if present
+		if d.issuer != "" {
+			issuerBytes, innerErr := hex.DecodeString(d.issuer)
+			if innerErr != nil {
+				err = fmt.Errorf("failed to decode issuer: %w", innerErr)
+				return
+			}
+			b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
+				b.AddBytes(issuerBytes)
+			})
+		}
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return b.Bytes()
+}
+
+func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
+	l := len(b)
+	if l == 0 || l > MaxCertificateSize {
+		return nil, ErrBadFormat
+	}
+
+	input := cryptobyte.String(b)
+	// Open the envelope
+	if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the cert details, we need to preserve the tag and length
+	var rawDetails cryptobyte.String
+	if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	//Maybe grab the curve
+	var rawCurve byte
+	if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
+		return nil, ErrBadFormat
+	}
+	curve = Curve(rawCurve)
+
+	// Maybe grab the public key
+	var rawPublicKey cryptobyte.String
+	if len(publicKey) > 0 {
+		rawPublicKey = publicKey
+	} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
+		return nil, ErrBadFormat
+	}
+
+	if len(rawPublicKey) == 0 {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the signature
+	var rawSignature cryptobyte.String
+	if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Finally unmarshal the details
+	details, err := unmarshalDetails(rawDetails)
+	if err != nil {
+		return nil, err
+	}
+
+	c := &certificateV2{
+		details:    details,
+		rawDetails: rawDetails,
+		curve:      curve,
+		publicKey:  rawPublicKey,
+		signature:  rawSignature,
+	}
+
+	err = c.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
+}
+
+func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
+	// Open the envelope
+	if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the name
+	var name cryptobyte.String
+	if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the network addresses
+	var subString cryptobyte.String
+	var found bool
+
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var networks []netip.Prefix
+	var val cryptobyte.String
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			networks = append(networks, n)
+		}
+	}
+
+	// Read out any unsafe networks
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var unsafeNetworks []netip.Prefix
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			unsafeNetworks = append(unsafeNetworks, n)
+		}
+	}
+
+	// Read out any groups
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var groups []string
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
+				return detailsV2{}, ErrBadFormat
+			}
+			groups = append(groups, string(val))
+		}
+	}
+
+	// Read out IsCA
+	var isCa bool
+	if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read not before and not after
+	var notBefore int64
+	if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var notAfter int64
+	if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read issuer
+	var issuer cryptobyte.String
+	if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	return detailsV2{
+		name:           string(name),
+		networks:       networks,
+		unsafeNetworks: unsafeNetworks,
+		groups:         groups,
+		isCA:           isCa,
+		notBefore:      time.Unix(notBefore, 0),
+		notAfter:       time.Unix(notAfter, 0),
+		issuer:         hex.EncodeToString(issuer),
+	}, nil
+}

+ 267 - 0
cert/cert_v2_test.go

@@ -0,0 +1,267 @@
+package cert
+
+import (
+	"crypto/ed25519"
+	"crypto/rand"
+	"encoding/hex"
+	"net/netip"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCertificateV2_Marshal(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	nc.rawDetails = db
+
+	b, err := nc.Marshal()
+	require.Nil(t, err)
+	//t.Log("Cert size:", len(b))
+
+	nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
+	assert.Nil(t, err)
+
+	assert.Equal(t, nc.Version(), Version2)
+	assert.Equal(t, nc.Curve(), Curve_CURVE25519)
+	assert.Equal(t, nc.Signature(), nc2.Signature())
+	assert.Equal(t, nc.Name(), nc2.Name())
+	assert.Equal(t, nc.NotBefore(), nc2.NotBefore())
+	assert.Equal(t, nc.NotAfter(), nc2.NotAfter())
+	assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
+	assert.Equal(t, nc.IsCA(), nc2.IsCA())
+	assert.Equal(t, nc.Issuer(), nc2.Issuer())
+
+	// unmarshalling will sort networks and unsafeNetworks, we need to do the same
+	// but first make sure it fails
+	assert.NotEqual(t, nc.Networks(), nc2.Networks())
+	assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	slices.SortFunc(nc.details.networks, comparePrefix)
+	slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
+
+	assert.Equal(t, nc.Networks(), nc2.Networks())
+	assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks())
+
+	assert.Equal(t, nc.Groups(), nc2.Groups())
+}
+
+func TestCertificateV2_Expired(t *testing.T) {
+	nc := certificateV2{
+		details: detailsV2{
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+		},
+	}
+
+	assert.True(t, nc.Expired(time.Now().Add(time.Hour)))
+	assert.True(t, nc.Expired(time.Now().Add(-time.Hour)))
+	assert.False(t, nc.Expired(time.Now()))
+}
+
+func TestCertificateV2_MarshalJSON(t *testing.T) {
+	time.Local = time.UTC
+	pubKey := []byte("1234567890abcedf1234567890abcedf")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			isCA:      false,
+			issuer:    "1234567890abcedf1234567890abcedf",
+		},
+		publicKey: pubKey,
+		signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"),
+	}
+
+	b, err := nc.MarshalJSON()
+	assert.ErrorIs(t, err, ErrMissingDetails)
+
+	rd, err := nc.details.Marshal()
+	assert.NoError(t, err)
+
+	nc.rawDetails = rd
+	b, err = nc.MarshalJSON()
+	assert.Nil(t, err)
+	assert.Equal(
+		t,
+		"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
+		string(b),
+	)
+}
+
+func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	_, caKey2, err := ed25519.GenerateKey(rand.Reader)
+	require.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
+	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_CURVE25519, curve)
+	err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := X25519Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
+	assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
+
+	err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	ac, ok := c.(*certificateV2)
+	require.True(t, ok)
+	ac.curve = Curve(99)
+	err = c.VerifyPrivateKey(Curve(99), priv2)
+	assert.EqualError(t, err, "invalid curve: 99")
+
+	ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
+	assert.Nil(t, err)
+
+	err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv[:16])
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	err = c.VerifyPrivateKey(Curve_P256, priv)
+	assert.ErrorIs(t, err, ErrInvalidPrivateKey)
+
+	aCa, ok := ca2.(*certificateV2)
+	require.True(t, ok)
+	aCa.curve = Curve(99)
+	err = aCa.VerifyPrivateKey(Curve(99), priv2)
+	assert.EqualError(t, err, "invalid curve: 99")
+
+}
+
+func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	err := ca.VerifyPrivateKey(Curve_P256, caKey)
+	assert.Nil(t, err)
+
+	_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
+	assert.Nil(t, err)
+	err = ca.VerifyPrivateKey(Curve_P256, caKey2)
+	assert.NotNil(t, err)
+
+	c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
+	rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
+	assert.NoError(t, err)
+	assert.Empty(t, b)
+	assert.Equal(t, Curve_P256, curve)
+	err = c.VerifyPrivateKey(Curve_P256, rawPriv)
+	assert.Nil(t, err)
+
+	_, priv2 := P256Keypair()
+	err = c.VerifyPrivateKey(Curve_P256, priv2)
+	assert.NotNil(t, err)
+}
+
+func TestCertificateV2_Copy(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
+	c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
+	cc := c.Copy()
+	test.AssertDeepCopyEqual(t, c, cc)
+}
+
+func TestUnmarshalCertificateV2(t *testing.T) {
+	data := []byte("\x98\x00\x00")
+	_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
+	assert.EqualError(t, err, "bad wire format")
+}
+
+func TestCertificateV2_marshalForSigningStability(t *testing.T) {
+	before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC)
+	after := before.Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	nc := certificateV2{
+		details: detailsV2{
+			name: "testing",
+			networks: []netip.Prefix{
+				mustParsePrefixUnmapped("10.1.1.2/16"),
+				mustParsePrefixUnmapped("10.1.1.1/24"),
+			},
+			unsafeNetworks: []netip.Prefix{
+				mustParsePrefixUnmapped("9.1.1.3/16"),
+				mustParsePrefixUnmapped("9.1.1.2/24"),
+			},
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			isCA:      false,
+			issuer:    "1234567890abcdef1234567890abcdef",
+		},
+		signature: []byte("1234567890abcdef1234567890abcdef"),
+		publicKey: pubKey,
+	}
+
+	const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef"
+	expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr)
+	require.NoError(t, err)
+
+	db, err := nc.details.Marshal()
+	require.NoError(t, err)
+	assert.Equal(t, expectedRawDetails, db)
+
+	expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162")
+	b, err := nc.marshalForSigning()
+	require.NoError(t, err)
+	assert.Equal(t, expectedForSigning, b)
+}

+ 34 - 12
cert/errors.go

@@ -2,21 +2,24 @@ package cert
 
 import (
 	"errors"
+	"fmt"
 )
 
 var (
-	ErrBadFormat               = errors.New("bad wire format")
-	ErrRootExpired             = errors.New("root certificate is expired")
-	ErrExpired                 = errors.New("certificate is expired")
-	ErrNotCA                   = errors.New("certificate is not a CA")
-	ErrNotSelfSigned           = errors.New("certificate is not self-signed")
-	ErrBlockListed             = errors.New("certificate is in the block list")
-	ErrFingerprintMismatch     = errors.New("certificate fingerprint did not match")
-	ErrSignatureMismatch       = errors.New("certificate signature did not match")
-	ErrInvalidPublicKeyLength  = errors.New("invalid public key length")
-	ErrInvalidPrivateKeyLength = errors.New("invalid private key length")
-
-	ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
+	ErrBadFormat                  = errors.New("bad wire format")
+	ErrRootExpired                = errors.New("root certificate is expired")
+	ErrExpired                    = errors.New("certificate is expired")
+	ErrNotCA                      = errors.New("certificate is not a CA")
+	ErrNotSelfSigned              = errors.New("certificate is not self-signed")
+	ErrBlockListed                = errors.New("certificate is in the block list")
+	ErrFingerprintMismatch        = errors.New("certificate fingerprint did not match")
+	ErrSignatureMismatch          = errors.New("certificate signature did not match")
+	ErrInvalidPublicKey           = errors.New("invalid public key")
+	ErrInvalidPrivateKey          = errors.New("invalid private key")
+	ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
+	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
+	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
+	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
 
 	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
 	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")
@@ -24,4 +27,23 @@ var (
 	ErrInvalidPEMX25519PrivateKeyBanner  = errors.New("bytes did not contain a proper X25519 private key banner")
 	ErrInvalidPEMEd25519PublicKeyBanner  = errors.New("bytes did not contain a proper Ed25519 public key banner")
 	ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
+
+	ErrNoPeerStaticKey = errors.New("no peer static key was present")
+	ErrNoPayload       = errors.New("provided payload was empty")
+
+	ErrMissingDetails  = errors.New("certificate did not contain details")
+	ErrEmptySignature  = errors.New("empty signature")
+	ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
 )
+
+type ErrInvalidCertificateProperties struct {
+	str string
+}
+
+func NewErrInvalidCertificateProperties(format string, a ...any) error {
+	return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
+}
+
+func (e *ErrInvalidCertificateProperties) Error() string {
+	return e.str
+}

+ 141 - 0
cert/helper_test.go

@@ -0,0 +1,141 @@
+package cert
+
+import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"io"
+	"net/netip"
+	"time"
+
+	"golang.org/x/crypto/curve25519"
+	"golang.org/x/crypto/ed25519"
+)
+
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	t := &TBSCertificate{
+		Curve:          curve,
+		Version:        version,
+		Name:           "test ca",
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		IsCA:           true,
+	}
+
+	c, err := t.Sign(nil, curve, priv)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, priv, pem
+}
+
+// NewTestCert will generate a signed certificate with the provided details.
+// Expiry times are defaulted if you do not pass them in
+func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) {
+	if before.IsZero() {
+		before = time.Now().Add(time.Second * -60).Round(time.Second)
+	}
+
+	if after.IsZero() {
+		after = time.Now().Add(time.Second * 60).Round(time.Second)
+	}
+
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
+	var pub, priv []byte
+	switch curve {
+	case Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
+	nc := &TBSCertificate{
+		Version:        v,
+		Curve:          curve,
+		Name:           name,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
+		Groups:         groups,
+		NotBefore:      time.Unix(before.Unix(), 0),
+		NotAfter:       time.Unix(after.Unix(), 0),
+		PublicKey:      pub,
+		IsCA:           false,
+	}
+
+	c, err := nc.Sign(ca, ca.Curve(), key)
+	if err != nil {
+		panic(err)
+	}
+
+	pem, err := c.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem
+}
+
+func X25519Keypair() ([]byte, []byte) {
+	privkey := make([]byte, 32)
+	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
+		panic(err)
+	}
+
+	pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint)
+	if err != nil {
+		panic(err)
+	}
+
+	return pubkey, privkey
+}
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 13 - 7
cert/pem.go

@@ -30,19 +30,25 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
 		return nil, r, ErrInvalidPEMBlock
 	}
 
+	var c Certificate
+	var err error
+
 	switch p.Type {
+	// Implementations must validate the resulting certificate contains valid information
 	case CertificateBanner:
-		c, err := unmarshalCertificateV1(p.Bytes, true)
-		if err != nil {
-			return nil, nil, err
-		}
-		return c, r, nil
+		c, err = unmarshalCertificateV1(p.Bytes, nil)
 	case CertificateV2Banner:
-		//TODO
-		panic("TODO")
+		c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
 	default:
 		return nil, r, ErrInvalidPEMCertificateBanner
 	}
+
+	if err != nil {
+		return nil, r, err
+	}
+
+	return c, r, nil
+
 }
 
 func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {

+ 105 - 14
cert/sign.go

@@ -1,11 +1,15 @@
 package cert
 
 import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/sha256"
 	"fmt"
+	"math/big"
 	"net/netip"
 	"time"
-
-	"github.com/slackhq/nebula/pkclient"
 )
 
 // TBSCertificate represents a certificate intended to be signed.
@@ -24,28 +28,61 @@ type TBSCertificate struct {
 	issuer         string
 }
 
+type beingSignedCertificate interface {
+	// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
+	// Implementations must validate the resulting certificate contains valid information
+	fromTBSCertificate(*TBSCertificate) error
+
+	// marshalForSigning returns the bytes that should be signed
+	marshalForSigning() ([]byte, error)
+
+	// setSignature sets the signature for the certificate that has just been signed. The signature must not be blank.
+	setSignature([]byte) error
+}
+
+type SignerLambda func(certBytes []byte) ([]byte, error)
+
 // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
 // details do not violate constraints of the signing certificate.
 // If the TBSCertificate is a CA then signer must be nil.
 func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
-	return t.sign(signer, curve, key, nil)
-}
-
-func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) {
-	if curve != Curve_P256 {
-		return nil, fmt.Errorf("only P256 is supported by PKCS#11")
+	switch t.Curve {
+	case Curve_CURVE25519:
+		pk := ed25519.PrivateKey(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			sig := ed25519.Sign(pk, certBytes)
+			return sig, nil
+		}
+		return t.SignWith(signer, curve, sp)
+	case Curve_P256:
+		pk := &ecdsa.PrivateKey{
+			PublicKey: ecdsa.PublicKey{
+				Curve: elliptic.P256(),
+			},
+			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+			D: new(big.Int).SetBytes(key),
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			// We need to hash first for ECDSA
+			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
+			hashed := sha256.Sum256(certBytes)
+			return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
+		}
+		return t.SignWith(signer, curve, sp)
+	default:
+		return nil, fmt.Errorf("invalid curve: %s", t.Curve)
 	}
-
-	return t.sign(signer, curve, nil, client)
 }
 
-func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) {
+// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
+// You should only use SignWith if you do not have direct access to your private key.
+func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
 	if curve != t.Curve {
 		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
 	}
 
-	//TODO: make sure we have all minimum properties to sign, like a public key
-
 	if signer != nil {
 		if t.IsCA {
 			return nil, fmt.Errorf("can not sign a CA certificate with another")
@@ -67,10 +104,64 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien
 		}
 	}
 
+	var c beingSignedCertificate
 	switch t.Version {
 	case Version1:
-		return signV1(t, curve, key, client)
+		c = &certificateV1{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	case Version2:
+		c = &certificateV2{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
 	default:
 		return nil, fmt.Errorf("unknown cert version %d", t.Version)
 	}
+
+	certBytes, err := c.marshalForSigning()
+	if err != nil {
+		return nil, err
+	}
+
+	sig, err := sp(certBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	err = c.setSignature(sig)
+	if err != nil {
+		return nil, err
+	}
+
+	sc, ok := c.(Certificate)
+	if !ok {
+		return nil, fmt.Errorf("invalid certificate")
+	}
+
+	return sc, nil
+}
+
+func comparePrefix(a, b netip.Prefix) int {
+	addr := a.Addr().Compare(b.Addr())
+	if addr == 0 {
+		return a.Bits() - b.Bits()
+	}
+	return addr
+}
+
+// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
+func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
+	if len(sortedPrefixes) < 2 {
+		return nil
+	}
+	for i := 1; i < len(sortedPrefixes); i++ {
+		if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
+			return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
+		}
+	}
+	return nil
 }

+ 90 - 0
cert/sign_test.go

@@ -0,0 +1,90 @@
+package cert
+
+import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"net/netip"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCertificateV1_Sign(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/24"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+	}
+
+	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
+	assert.Nil(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	assert.Nil(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, uc)
+}
+
+func TestCertificateV1_SignP256(t *testing.T) {
+	before := time.Now().Add(time.Second * -60).Round(time.Second)
+	after := time.Now().Add(time.Second * 60).Round(time.Second)
+	pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab")
+
+	tbs := TBSCertificate{
+		Version: Version1,
+		Name:    "testing",
+		Networks: []netip.Prefix{
+			mustParsePrefixUnmapped("10.1.1.1/24"),
+			mustParsePrefixUnmapped("10.1.1.2/16"),
+		},
+		UnsafeNetworks: []netip.Prefix{
+			mustParsePrefixUnmapped("9.1.1.2/24"),
+			mustParsePrefixUnmapped("9.1.1.3/16"),
+		},
+		Groups:    []string{"test-group1", "test-group2", "test-group3"},
+		NotBefore: before,
+		NotAfter:  after,
+		PublicKey: pubKey,
+		IsCA:      false,
+		Curve:     Curve_P256,
+	}
+
+	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+	assert.NoError(t, err)
+	pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
+	rawPriv := priv.D.FillBytes(make([]byte, 32))
+
+	c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
+	assert.Nil(t, err)
+	assert.NotNil(t, c)
+	assert.True(t, c.CheckSignature(pub))
+
+	b, err := c.Marshal()
+	assert.Nil(t, err)
+	uc, err := unmarshalCertificateV1(b, nil)
+	assert.Nil(t, err)
+	assert.NotNil(t, uc)
+}

+ 51 - 11
e2e/helpers.go → cert_test/cert.go

@@ -1,6 +1,9 @@
-package e2e
+package cert_test
 
 import (
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rand"
 	"io"
 	"net/netip"
@@ -11,9 +14,26 @@ import (
 	"golang.org/x/crypto/ed25519"
 )
 
-// NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
-	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+// NewTestCaCert will create a new ca certificate
+func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+	var err error
+	var pub, priv []byte
+
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv, err = ed25519.GenerateKey(rand.Reader)
+	case cert.Curve_P256:
+		privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+
+		pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y)
+		priv = privk.D.FillBytes(make([]byte, 32))
+	default:
+		// There is no default to allow the underlying lib to respond with an error
+	}
+
 	if before.IsZero() {
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
 	}
@@ -22,7 +42,8 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
 	}
 
 	t := &cert.TBSCertificate{
-		Version:        cert.Version1,
+		Curve:          curve,
+		Version:        version,
 		Name:           "test ca",
 		NotBefore:      time.Unix(before.Unix(), 0),
 		NotAfter:       time.Unix(after.Unix(), 0),
@@ -33,7 +54,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
 		IsCA:           true,
 	}
 
-	c, err := t.Sign(nil, cert.Curve_CURVE25519, priv)
+	c, err := t.Sign(nil, curve, priv)
 	if err != nil {
 		panic(err)
 	}
@@ -48,7 +69,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
 
 // NewTestCert will generate a signed certificate with the provided details.
 // Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
 	if before.IsZero() {
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
 	}
@@ -57,9 +78,19 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
 		after = time.Now().Add(time.Second * 60).Round(time.Second)
 	}
 
-	pub, rawPriv := x25519Keypair()
+	var pub, priv []byte
+	switch curve {
+	case cert.Curve_CURVE25519:
+		pub, priv = X25519Keypair()
+	case cert.Curve_P256:
+		pub, priv = P256Keypair()
+	default:
+		panic("unknown curve")
+	}
+
 	nc := &cert.TBSCertificate{
-		Version:        cert.Version1,
+		Version:        v,
+		Curve:          curve,
 		Name:           name,
 		Networks:       networks,
 		UnsafeNetworks: unsafeNetworks,
@@ -80,10 +111,10 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
 		panic(err)
 	}
 
-	return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem
+	return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem
 }
 
-func x25519Keypair() ([]byte, []byte) {
+func X25519Keypair() ([]byte, []byte) {
 	privkey := make([]byte, 32)
 	if _, err := io.ReadFull(rand.Reader, privkey); err != nil {
 		panic(err)
@@ -96,3 +127,12 @@ func x25519Keypair() ([]byte, []byte) {
 
 	return pubkey, privkey
 }
+
+func P256Keypair() ([]byte, []byte) {
+	privkey, err := ecdh.P256().GenerateKey(rand.Reader)
+	if err != nil {
+		panic(err)
+	}
+	pubkey := privkey.PublicKey()
+	return pubkey.Bytes(), privkey.Bytes()
+}

+ 46 - 22
cmd/nebula-cert/ca.go

@@ -27,34 +27,43 @@ type caFlags struct {
 	outCertPath      *string
 	outQRPath        *string
 	groups           *string
-	ips              *string
-	subnets          *string
+	networks         *string
+	unsafeNetworks   *string
 	argonMemory      *uint
 	argonIterations  *uint
 	argonParallelism *uint
 	encryption       *bool
+	version          *uint
 
 	curve  *string
 	p11url *string
+
+	// Deprecated options
+	ips     *string
+	subnets *string
 }
 
 func newCaFlags() *caFlags {
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf.set.Usage = func() {}
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
+	cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
-	cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
-	cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
+	cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
+	cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
 	cf.p11url = p11Flag(cf.set)
+
+	cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
+	cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &cf
 }
 
@@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		}
 	}
 
-	var ips []netip.Prefix
-	if *cf.ips != "" {
-		for _, rs := range strings.Split(*cf.ips, ",") {
+	version := cert.Version(*cf.version)
+	if version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
+	var networks []netip.Prefix
+	if *cf.networks == "" && *cf.ips != "" {
+		// Pull up deprecated -ips flag if needed
+		*cf.networks = *cf.ips
+	}
+
+	if *cf.networks != "" {
+		for _, rs := range strings.Split(*cf.networks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid ip definition: %s", err)
+					return newHelpErrorf("invalid -networks definition: %s", rs)
 				}
-				if !n.Addr().Is4() {
-					return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
-				ips = append(ips, n)
+				networks = append(networks, n)
 			}
 		}
 	}
 
-	var subnets []netip.Prefix
-	if *cf.subnets != "" {
-		for _, rs := range strings.Split(*cf.subnets, ",") {
+	var unsafeNetworks []netip.Prefix
+	if *cf.unsafeNetworks == "" && *cf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*cf.unsafeNetworks = *cf.subnets
+	}
+
+	if *cf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
-				if !n.Addr().Is4() {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
-				subnets = append(subnets, n)
+				unsafeNetworks = append(unsafeNetworks, n)
 			}
 		}
 	}
@@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 	}
 
 	t := &cert.TBSCertificate{
-		Version:        cert.Version1,
+		Version:        version,
 		Name:           *cf.name,
 		Groups:         groups,
-		Networks:       ips,
-		UnsafeNetworks: subnets,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
 		NotBefore:      time.Now(),
 		NotAfter:       time.Now().Add(*cf.duration),
 		PublicKey:      pub,
@@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 	var b []byte
 
 	if isP11 {
-		c, err = t.SignPkcs11(nil, curve, p11Client)
+		c, err = t.SignWith(nil, curve, p11Client.SignASN1)
 		if err != nil {
 			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 		}

+ 20 - 16
cmd/nebula-cert/ca_test.go

@@ -16,8 +16,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//TODO: test file permissions
-
 func Test_caSummary(t *testing.T) {
 	assert.Equal(t, "ca <flags>: create a self signed certificate authority", caSummary())
 }
@@ -43,9 +41,11 @@ func Test_caHelp(t *testing.T) {
 			"  -groups string\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"  -ips string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
+			"    	Deprecated, see -networks\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the certificate authority\n"+
+			"  -networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"  -out-key string\n"+
@@ -54,7 +54,11 @@ func Test_caHelp(t *testing.T) {
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use (default 2)\n",
 		ob.String(),
 	)
 }
@@ -83,25 +87,25 @@ func Test_ca(t *testing.T) {
 
 	// required args
 	assertHelpError(t, ca(
-		[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
+		[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
 	), "-name is required")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// ipv4 only ips
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// ipv4 only subnets
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 
 	// failed key write
 	ob.Reset()
 	eb.Reset()
-	args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
+	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -114,7 +118,7 @@ func Test_ca(t *testing.T) {
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -128,7 +132,7 @@ func Test_ca(t *testing.T) {
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -161,7 +165,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
@@ -189,7 +193,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
@@ -199,7 +203,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, "", eb.String())
@@ -209,13 +213,13 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
@@ -224,7 +228,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())

+ 0 - 2
cmd/nebula-cert/keygen_test.go

@@ -9,8 +9,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//TODO: test file permissions
-
 func Test_keygenSummary(t *testing.T) {
 	assert.Equal(t, "keygen <flags>: create a public/private key pair. the public key can be passed to `nebula-cert sign`", keygenSummary())
 }

+ 0 - 2
cmd/nebula-cert/main_test.go

@@ -11,8 +11,6 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//TODO: all flag parsing continueOnError will print to stderr on its own currently
-
 func Test_help(t *testing.T) {
 	expected := "Usage of " + os.Args[0] + " <global flags> <mode>:\n" +
 		"  Global flags:\n" +

+ 11 - 6
cmd/nebula-cert/print.go

@@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 	var qrBytes []byte
 	part := 0
 
+	var jsonCerts []cert.Certificate
+
 	for {
 		c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
 		if err != nil {
@@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		}
 
 		if *pf.json {
-			b, _ := json.Marshal(c)
-			out.Write(b)
-			out.Write([]byte("\n"))
-
+			jsonCerts = append(jsonCerts, c)
 		} else {
-			out.Write([]byte(c.String()))
-			out.Write([]byte("\n"))
+			_, _ = out.Write([]byte(c.String()))
+			_, _ = out.Write([]byte("\n"))
 		}
 
 		if *pf.outQRPath != "" {
@@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		part++
 	}
 
+	if *pf.json {
+		b, _ := json.Marshal(jsonCerts)
+		_, _ = out.Write(b)
+		_, _ = out.Write([]byte("\n"))
+	}
+
 	if *pf.outQRPath != "" {
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		if err != nil {

+ 72 - 3
cmd/nebula-cert/print_test.go

@@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) {
 	tf.Truncate(0)
 	tf.Seek(0, 0)
 	ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
-	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"})
+	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
 
 	p, _ := c.MarshalPEM()
 	tf.Write(p)
@@ -87,7 +87,71 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		`{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [
+			"10.0.0.123/8"
+		],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+`,
 		ob.String(),
 	)
 	assert.Equal(t, "", eb.String())
@@ -108,7 +172,8 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n",
+		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
+`,
 		ob.String(),
 	)
 	assert.Equal(t, "", eb.String())
@@ -153,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft
 		after = ca.NotAfter()
 	}
 
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
 	pub, rawPriv := x25519Keypair()
 	nc := &cert.TBSCertificate{
 		Version:        cert.Version1,

+ 166 - 66
cmd/nebula-cert/sign.go

@@ -3,6 +3,7 @@ package main
 import (
 	"crypto/ecdh"
 	"crypto/rand"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -18,36 +19,46 @@ import (
 )
 
 type signFlags struct {
-	set         *flag.FlagSet
-	caKeyPath   *string
-	caCertPath  *string
-	name        *string
-	ip          *string
-	duration    *time.Duration
-	inPubPath   *string
-	outKeyPath  *string
-	outCertPath *string
-	outQRPath   *string
-	groups      *string
-	subnets     *string
-	p11url      *string
+	set            *flag.FlagSet
+	version        *uint
+	caKeyPath      *string
+	caCertPath     *string
+	name           *string
+	networks       *string
+	unsafeNetworks *string
+	duration       *time.Duration
+	inPubPath      *string
+	outKeyPath     *string
+	outCertPath    *string
+	outQRPath      *string
+	groups         *string
+
+	p11url *string
+
+	// Deprecated options
+	ip      *string
+	subnets *string
 }
 
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
-	sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
+	sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
+	sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
-	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
 	sf.p11url = p11Flag(sf.set)
+
+	sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
+	sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &sf
 }
 
@@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	if err := mustFlagString("name", sf.name); err != nil {
 		return err
 	}
-	if err := mustFlagString("ip", sf.ip); err != nil {
-		return err
-	}
 	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 
+	var v4Networks []netip.Prefix
+	var v6Networks []netip.Prefix
+	if *sf.networks == "" && *sf.ip != "" {
+		// Pull up deprecated -ip flag if needed
+		*sf.networks = *sf.ip
+	}
+
+	if len(*sf.networks) == 0 {
+		return newHelpErrorf("-networks is required")
+	}
+
+	version := cert.Version(*sf.version)
+	if version != 0 && version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
 	var curve cert.Curve
 	var caKey []byte
 
@@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 		// naively attempt to decode the private key as though it is not encrypted
 		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
-		if err == cert.ErrPrivateKeyEncrypted {
+		if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
 			// ask for a passphrase until we get one
 			var passphrase []byte
 			for i := 0; i < 5; i++ {
 				out.Write([]byte("Enter passphrase: "))
 				passphrase, err = pr.ReadPassword()
 
-				if err == ErrNoTerminal {
+				if errors.Is(err, ErrNoTerminal) {
 					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
 				} else if err != nil {
 					return fmt.Errorf("error reading password: %s", err)
@@ -146,37 +170,55 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
 	}
 
-	network, err := netip.ParsePrefix(*sf.ip)
-	if err != nil {
-		return newHelpErrorf("invalid ip definition: %s", *sf.ip)
-	}
-	if !network.Addr().Is4() {
-		return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
-	}
+	if *sf.networks != "" {
+		for _, rs := range strings.Split(*sf.networks, ",") {
+			rs := strings.Trim(rs, " ")
+			if rs != "" {
+				n, err := netip.ParsePrefix(rs)
+				if err != nil {
+					return newHelpErrorf("invalid -networks definition: %s", rs)
+				}
 
-	var groups []string
-	if *sf.groups != "" {
-		for _, rg := range strings.Split(*sf.groups, ",") {
-			g := strings.TrimSpace(rg)
-			if g != "" {
-				groups = append(groups, g)
+				if n.Addr().Is4() {
+					v4Networks = append(v4Networks, n)
+				} else {
+					v6Networks = append(v6Networks, n)
+				}
 			}
 		}
 	}
 
-	var subnets []netip.Prefix
-	if *sf.subnets != "" {
-		for _, rs := range strings.Split(*sf.subnets, ",") {
+	var v4UnsafeNetworks []netip.Prefix
+	var v6UnsafeNetworks []netip.Prefix
+	if *sf.unsafeNetworks == "" && *sf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*sf.unsafeNetworks = *sf.subnets
+	}
+
+	if *sf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
-				s, err := netip.ParsePrefix(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", rs)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
-				if !s.Addr().Is4() {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+
+				if n.Addr().Is4() {
+					v4UnsafeNetworks = append(v4UnsafeNetworks, n)
+				} else {
+					v6UnsafeNetworks = append(v6UnsafeNetworks, n)
 				}
-				subnets = append(subnets, s)
+			}
+		}
+	}
+
+	var groups []string
+	if *sf.groups != "" {
+		for _, rg := range strings.Split(*sf.groups, ",") {
+			g := strings.TrimSpace(rg)
+			if g != "" {
+				groups = append(groups, g)
 			}
 		}
 	}
@@ -218,19 +260,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		pub, rawPriv = newKeypair(curve)
 	}
 
-	t := &cert.TBSCertificate{
-		Version:        cert.Version1,
-		Name:           *sf.name,
-		Networks:       []netip.Prefix{network},
-		Groups:         groups,
-		UnsafeNetworks: subnets,
-		NotBefore:      time.Now(),
-		NotAfter:       time.Now().Add(*sf.duration),
-		PublicKey:      pub,
-		IsCA:           false,
-		Curve:          curve,
-	}
-
 	if *sf.outKeyPath == "" {
 		*sf.outKeyPath = *sf.name + ".key"
 	}
@@ -243,18 +272,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 
-	var c cert.Certificate
+	var crts []cert.Certificate
 
-	if p11Client == nil {
-		c, err = t.Sign(caCert, curve, caKey)
-		if err != nil {
-			return fmt.Errorf("error while signing: %w", err)
+	notBefore := time.Now()
+	notAfter := notBefore.Add(*sf.duration)
+
+	if version == 0 || version == cert.Version1 {
+		// Make sure we at least have an ip
+		if len(v4Networks) != 1 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
 		}
-	} else {
-		c, err = t.SignPkcs11(caCert, curve, p11Client)
-		if err != nil {
-			return fmt.Errorf("error while signing with PKCS#11: %w", err)
+
+		if version == cert.Version1 {
+			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
+			if len(v6Networks) > 0 {
+				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
+			}
+
+			if len(v6UnsafeNetworks) > 0 {
+				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+			}
+		}
+
+		t := &cert.TBSCertificate{
+			Version:        cert.Version1,
+			Name:           *sf.name,
+			Networks:       []netip.Prefix{v4Networks[0]},
+			Groups:         groups,
+			UnsafeNetworks: v4UnsafeNetworks,
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
 		}
+
+		crts = append(crts, nc)
+	}
+
+	if version == 0 || version == cert.Version2 {
+		t := &cert.TBSCertificate{
+			Version:        cert.Version2,
+			Name:           *sf.name,
+			Networks:       append(v4Networks, v6Networks...),
+			Groups:         groups,
+			UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
 	}
 
 	if !isP11 && *sf.inPubPath == "" {
@@ -268,9 +364,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		}
 	}
 
-	b, err := c.MarshalPEM()
-	if err != nil {
-		return fmt.Errorf("error while marshalling certificate: %s", err)
+	var b []byte
+	for _, c := range crts {
+		sb, err := c.MarshalPEM()
+		if err != nil {
+			return fmt.Errorf("error while marshalling certificate: %s", err)
+		}
+		b = append(b, sb...)
 	}
 
 	err = os.WriteFile(*sf.outCertPath, b, 0600)

+ 47 - 36
cmd/nebula-cert/sign_test.go

@@ -16,8 +16,6 @@ import (
 	"golang.org/x/crypto/ed25519"
 )
 
-//TODO: test file permissions
-
 func Test_signSummary(t *testing.T) {
 	assert.Equal(t, "sign <flags>: create and sign a certificate", signSummary())
 }
@@ -39,9 +37,11 @@ func Test_signHelp(t *testing.T) {
 			"  -in-pub string\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"  -ip string\n"+
-			"    \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
+			"    \tDeprecated, see -networks\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
+			"  -networks string\n"+
+			"    \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"  -out-key string\n"+
@@ -50,7 +50,11 @@ func Test_signHelp(t *testing.T) {
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
 		ob.String(),
 	)
 }
@@ -77,20 +81,20 @@ func Test_signCert(t *testing.T) {
 
 	// required args
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
 	), "-name is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
-	), "-ip is required")
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-networks is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// cannot set -in-pub and -out-key
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
 	), "cannot set both -in-pub and -out-key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -98,7 +102,7 @@ func Test_signCert(t *testing.T) {
 	// failed to read key
 	ob.Reset()
 	eb.Reset()
-	args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 	// failed to unmarshal key
@@ -108,7 +112,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -120,7 +124,7 @@ func Test_signCert(t *testing.T) {
 	caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
 
 	// failed to read cert
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -132,7 +136,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -143,7 +147,7 @@ func Test_signCert(t *testing.T) {
 	caCrtF.Write(b)
 
 	// failed to read pub
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -155,7 +159,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	defer os.Remove(inPubF.Name())
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -169,30 +173,37 @@ func Test_signCert(t *testing.T) {
 	// bad ip cidr
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	// bad subnet cidr
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 
@@ -205,7 +216,7 @@ func Test_signCert(t *testing.T) {
 
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -213,7 +224,7 @@ func Test_signCert(t *testing.T) {
 	// failed key write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -226,7 +237,7 @@ func Test_signCert(t *testing.T) {
 	// failed cert write
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -240,7 +251,7 @@ func Test_signCert(t *testing.T) {
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -283,7 +294,7 @@ func Test_signCert(t *testing.T) {
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -300,7 +311,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -308,14 +319,14 @@ func Test_signCert(t *testing.T) {
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -323,14 +334,14 @@ func Test_signCert(t *testing.T) {
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	ob.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
@@ -362,7 +373,7 @@ func Test_signCert(t *testing.T) {
 	caCrtF.Write(b)
 
 	// test with the proper password
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
@@ -372,7 +383,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 
 	testpw.password = []byte("invalid password")
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
@@ -381,7 +392,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, nopw))
 	// normally the user hitting enter on the prompt would add newlines between these
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@@ -391,7 +402,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, errpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())

+ 25 - 14
cmd/nebula-cert/verify.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
-		return fmt.Errorf("error while reading ca: %s", err)
+		return fmt.Errorf("error while reading ca: %w", err)
 	}
 
 	caPool := cert.NewCAPool()
 	for {
 		rawCACert, err = caPool.AddCAFromPEM(rawCACert)
 		if err != nil {
-			return fmt.Errorf("error while adding ca cert to pool: %s", err)
+			return fmt.Errorf("error while adding ca cert to pool: %w", err)
 		}
 
 		if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
-		return fmt.Errorf("unable to read crt; %s", err)
+		return fmt.Errorf("unable to read crt: %w", err)
 	}
-
-	c, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
-	if err != nil {
-		return fmt.Errorf("error while parsing crt: %s", err)
-	}
-
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	if err != nil {
-		return err
+	var errs []error
+	for {
+		if len(rawCert) == 0 {
+			break
+		}
+		c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
+		if err != nil {
+			return fmt.Errorf("error while parsing crt: %w", err)
+		}
+		rawCert = extra
+		_, err = caPool.VerifyCertificate(time.Now(), c)
+		if err != nil {
+			switch {
+			case errors.Is(err, cert.ErrCaNotFound):
+				errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
+			default:
+				errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
+			}
+		}
 	}
 
-	return nil
+	return errors.Join(errs...)
 }
 
 func verifySummary() string {
@@ -80,7 +91,7 @@ func verifySummary() string {
 
 func verifyHelp(out io.Writer) {
 	vf := newVerifyFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
 	vf.set.SetOutput(out)
 	vf.set.PrintDefaults()
 }

+ 4 - 2
cmd/nebula-cert/verify_test.go

@@ -3,10 +3,12 @@ package main
 import (
 	"bytes"
 	"crypto/rand"
+	"errors"
 	"os"
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/ed25519"
 )
@@ -76,7 +78,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
+	assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 	// invalid crt at path
 	ob.Reset()
@@ -106,7 +108,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "certificate signature did not match")
+	assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
 
 	// verified cert at path
 	crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)

+ 0 - 3
config/config_test.go

@@ -38,9 +38,6 @@ func TestConfig_Load(t *testing.T) {
 		"new": "hi",
 	}
 	assert.Equal(t, expected, c.Settings)
-
-	//TODO: test symlinked file
-	//TODO: test symlinked directory
 }
 
 func TestConfig_Get(t *testing.T) {

+ 57 - 32
connection_manager.go

@@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 	case deleteTunnel:
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 
 	case closeTunnel:
@@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 	for _, r := range relayFor {
-		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
 
 		var index uint32
 		var relayFrom netip.Addr
@@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = existing.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = existing.PeerAddr
 			case ForwardingType:
-				relayFrom = existing.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = existing.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
 			}
@@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			n.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = r.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = r.PeerAddr
 			case ForwardingType:
-				relayFrom = r.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = r.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 				// should never happen
 			}
 		}
 
-		//TODO: IPV6-WORK
-		relayFromB := relayFrom.As4()
-		relayToB := relayTo.As4()
-
 		// Send a CreateRelayRequest to the peer.
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         binary.BigEndian.Uint32(relayFromB[:]),
-			RelayToIp:           binary.BigEndian.Uint32(relayToB[:]),
 		}
+
+		switch newhostinfo.GetCert().Certificate.Version() {
+		case cert.Version1:
+			if !relayFrom.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				continue
+			}
+
+			if !relayTo.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				continue
+			}
+
+			b := relayFrom.As4()
+			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = relayTo.As4()
+			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		case cert.Version2:
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
+			req.RelayToAddr = netAddrToProtoAddr(relayTo)
+		default:
+			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			continue
+		}
+
 		msg, err := req.Marshal()
 		if err != nil {
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           req.RelayFromIp,
-				"relayTo":             req.RelayToIp,
+				"relayFrom":           req.RelayFromAddr,
+				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
-				"vpnIp":               newhostinfo.vpnIp}).
+				"vpnAddrs":            newhostinfo.vpnAddrs}).
 				Info("send CreateRelayRequest")
 		}
 	}
@@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		return closeTunnel, hostinfo, nil
 	}
 
-	primary := n.hostMap.Hosts[hostinfo.vpnIp]
+	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
@@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 
-	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
-		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
-		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
-		// The remotes vpn ip is lower than mine. I will not flip.
+	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
+	// vpn addr is static across all tunnels for this host pair so lets
+	// use that to determine if we should consider swapping.
+	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
+		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}
 
-	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature())
+	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
+	// settle down.
+	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 	n.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnIp] == primary {
+	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
 		n.hostMap.unlockedMakePrimary(current)
 	}
 	n.hostMap.Unlock()
@@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) {
+	cs := n.intf.pki.getCertState()
+	curCrt := hostinfo.ConnectionState.myCert
+	myCrt := cs.getCertificate(curCrt.Version())
+	if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}
 
-	n.l.WithField("vpnIp", hostinfo.vpnIp).
+	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }

+ 26 - 29
connection_manager_test.go

@@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &dummyCert{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 
@@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 	}
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &dummyCert{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
@@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
 
 // Check if we can disconnect the peer.
@@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 	// Generate keys for CA and peer's cert.
@@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{},
+		v1HandshakeBytes: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.connectionManager = nc
 
 	hostinfo := &HostInfo{
-		vpnIp: vpnIp,
+		vpnAddrs: []netip.Addr{vpnIp},
 		ConnectionState: &ConnectionState{
 			myCert:   &dummyCert{},
 			peerCert: cachedPeerCert,

+ 26 - 21
connection_state.go

@@ -3,6 +3,7 @@ package nebula
 import (
 	"crypto/rand"
 	"encoding/json"
+	"fmt"
 	"sync"
 	"sync/atomic"
 
@@ -26,46 +27,46 @@ type ConnectionState struct {
 	writeLock      sync.Mutex
 }
 
-func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
 	var dhFunc noise.DHFunc
-	switch certState.Certificate.Curve() {
+	switch crt.Curve() {
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
-		if certState.pkcs11Backed {
+		if cs.pkcs11Backed {
 			dhFunc = noiseutil.DHP256PKCS11
 		} else {
 			dhFunc = noiseutil.DHP256
 		}
 	default:
-		l.Errorf("invalid curve: %s", certState.Certificate.Curve())
-		return nil
+		return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
 	}
 
-	var cs noise.CipherSuite
-	if cipher == "chachapoly" {
-		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	var ncs noise.CipherSuite
+	if cs.cipher == "chachapoly" {
+		ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	} else {
-		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
+		ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 
-	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
+	static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
 
 	b := NewBits(ReplayWindow)
-	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
+	// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
 	b.Update(l, 0)
 
 	hs, err := noise.NewHandshakeState(noise.Config{
-		CipherSuite:           cs,
-		Random:                rand.Reader,
-		Pattern:               pattern,
-		Initiator:             initiator,
-		StaticKeypair:         static,
-		PresharedKey:          psk,
-		PresharedKeyPlacement: pskStage,
+		CipherSuite:   ncs,
+		Random:        rand.Reader,
+		Pattern:       pattern,
+		Initiator:     initiator,
+		StaticKeypair: static,
+		//NOTE: These should come from CertState (pki.go) when we finally implement it
+		PresharedKey:          []byte{},
+		PresharedKeyPlacement: 0,
 	})
 	if err != nil {
-		return nil
+		return nil, fmt.Errorf("NewConnectionState: %s", err)
 	}
 
 	// The queue and ready params prevent a counter race that would happen when
@@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		initiator: initiator,
 		window:    b,
-		myCert:    certState.Certificate,
+		myCert:    crt,
 	}
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	ci.messageCounter.Add(2)
 
-	return ci
+	return ci, nil
 }
 
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"message_counter": cs.messageCounter.Load(),
 	})
 }
+
+func (cs *ConnectionState) Curve() cert.Curve {
+	return cs.myCert.Curve()
+}

+ 26 - 25
control.go

@@ -19,9 +19,9 @@ import (
 type controlEach func(h *HostInfo)
 
 type controlHostLister interface {
-	QueryVpnIp(vpnIp netip.Addr) *HostInfo
+	QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
-	ForEachVpnIp(each controlEach)
+	ForEachVpnAddr(each controlEach)
 	GetPreferredRanges() []netip.Prefix
 }
 
@@ -37,7 +37,7 @@ type Control struct {
 }
 
 type ControlHostInfo struct {
-	VpnIp                  netip.Addr       `json:"vpnIp"`
+	VpnAddrs               []netip.Addr     `json:"vpnAddrs"`
 	LocalIndex             uint32           `json:"localIndex"`
 	RemoteIndex            uint32           `json:"remoteIndex"`
 	RemoteAddrs            []netip.AddrPort `json:"remoteAddrs"`
@@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
 func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
-	if c.f.myVpnNet.Addr() == vpnIp {
-		return c.f.pki.GetCertState().Certificate.Copy()
+	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
+	if found {
+		// Only returning the default certificate since its impossible
+		// for any other host but ourselves to have more than 1
+		return c.f.pki.getCertState().GetDefaultCertificate().Copy()
 	}
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 		return nil
 	}
@@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
 
 // PrintTunnel creates a new tunnel to the given vpn ip.
 func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 		return nil
 	}
@@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
 	return hi.CopyCache()
 }
 
-// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
+// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
-func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
+func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	if pending {
 		hl = c.f.handshakeManager
@@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 		hl = c.f.hostMap
 	}
 
-	h := hl.QueryVpnIp(vpnIp)
+	h := hl.QueryVpnAddr(vpnAddr)
 	if h == nil {
 		return nil
 	}
@@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 		return nil
 	}
@@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 		return false
 	}
@@ -224,19 +227,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 // CloseAllTunnels is just like CloseTunnel except it goes through and shuts them all down, optionally you can avoid shutting down lighthouse tunnels
 // the int returned is a count of tunnels closed
 func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
-	//TODO: this is probably better as a function in ConnectionManager or HostMap directly
-	lighthouses := c.f.lightHouse.GetLighthouses()
-
 	shutdown := func(h *HostInfo) {
-		if excludeLighthouses {
-			if _, ok := lighthouses[h.vpnIp]; ok {
-				return
-			}
+		if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) {
+			return
 		}
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.closeTunnel(h)
 
-		c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
+		c.l.WithField("vpnAddrs", h.vpnAddrs).WithField("udpAddr", h.remote).
 			Debug("Sending close tunnel message")
 		closed++
 	}
@@ -246,7 +244,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
-		relayingHosts[relayingHost.vpnIp] = relayingHost
+		relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
 	}
 	c.f.hostMap.Unlock()
 
@@ -254,7 +252,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Hosts map
 	c.f.hostMap.Lock()
 	for _, relayHost := range c.f.hostMap.Indexes {
-		if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
+		if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
 			hostInfos = append(hostInfos, relayHost)
 		}
 	}
@@ -274,9 +272,8 @@ func (c *Control) Device() overlay.Device {
 }
 
 func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
-
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp,
+		VpnAddrs:               make([]netip.Addr, len(h.vpnAddrs)),
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
@@ -285,6 +282,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 		CurrentRemote:          h.remote,
 	}
 
+	for i, a := range h.vpnAddrs {
+		chi.VpnAddrs[i] = a
+	}
+
 	if h.ConnectionState != nil {
 		chi.MessageCounter = h.ConnectionState.messageCounter.Load()
 	}
@@ -299,7 +300,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 	hosts := make([]ControlHostInfo, 0)
 	pr := hl.GetPreferredRanges()
-	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+	hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 	})
 	return hosts

+ 17 - 17
control_test.go

@@ -13,13 +13,13 @@ import (
 )
 
 func TestControl_GetHostInfoByVpnIp(t *testing.T) {
-	//TODO: with multiple certificate versions we have a problem with this test
+	//TODO: CERT-V2 with multiple certificate versions we have a problem with this test
 	// Some certs versions have different characteristics and each version implements their own Copy() func
 	// which means this is not a good place to test for exposing memory
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, netip.Prefix{})
+	hm := newHostMap(l)
 	hm.preferredRanges.Store(&[]netip.Prefix{})
 
 	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 
-	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
-	remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+	remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
 
 	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
 	assert.True(t, ok)
@@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
@@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         vpnIp2,
+		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
@@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 	}
 
-	thi := c.GetHostInfoByVpnIp(vpnIp, false)
+	thi := c.GetHostInfoByVpnAddr(vpnIp, false)
 
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  vpnIp,
+		VpnAddrs:               []netip.Addr{vpnIp},
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
@@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	assert.EqualValues(t, &expectedInfo, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(vpnIp2, false)
+		thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
 	})
 }
 

+ 43 - 22
control_tester.go

@@ -6,8 +6,6 @@ package nebula
 import (
 	"net/netip"
 
-	"github.com/slackhq/nebula/cert"
-
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
@@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
 	if toAddr.Addr().Is4() {
-		remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
 	} else {
-		remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
 	}
 }
 
@@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
+	remoteList.unlockedSetRelay(vpnIp, relayVpnIps)
 }
 
 // GetFromTun will pull a packet off the tun side of nebula
@@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
-	//TODO: IPV6-WORK
-	ip := layers.IPv4{
-		Version:  4,
-		TTL:      64,
-		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().Addr().Unmap().AsSlice(),
-		DstIP:    toIp.Unmap().AsSlice(),
+func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
+	serialize := make([]gopacket.SerializableLayer, 0)
+	var netLayer gopacket.NetworkLayer
+	if toAddr.Is6() {
+		if !fromAddr.Is6() {
+			panic("Cant send ipv6 to ipv4")
+		}
+		ip := &layers.IPv6{
+			Version:    6,
+			NextHeader: layers.IPProtocolUDP,
+			SrcIP:      fromAddr.Unmap().AsSlice(),
+			DstIP:      toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
+	} else {
+		if !fromAddr.Is4() {
+			panic("Cant send ipv4 to ipv6")
+		}
+
+		ip := &layers.IPv4{
+			Version:  4,
+			TTL:      64,
+			Protocol: layers.IPProtocolUDP,
+			SrcIP:    fromAddr.Unmap().AsSlice(),
+			DstIP:    toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
 	}
 
 	udp := layers.UDP{
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 	}
-	err := udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(netLayer)
 	if err != nil {
 		panic(err)
 	}
@@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 		ComputeChecksums: true,
 		FixLengths:       true,
 	}
-	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+
+	serialize = append(serialize, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, serialize...)
 	if err != nil {
 		panic(err)
 	}
@@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 
-func (c *Control) GetVpnIp() netip.Addr {
-	return c.f.myVpnNet.Addr()
+func (c *Control) GetVpnAddrs() []netip.Addr {
+	return c.f.myVpnAddrs
 }
 
 func (c *Control) GetUDPAddr() netip.AddrPort {
@@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
 }
 
 func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
+	hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
 	if hostinfo == nil {
 		return false
 	}
@@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 }
 
-func (c *Control) GetCert() cert.Certificate {
-	return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertState() *CertState {
+	return c.f.pki.getCertState()
 }
 
 func (c *Control) ReHandshake(vpnIp netip.Addr) {

+ 74 - 36
dns_server.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/gaissmai/bart"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -21,24 +22,39 @@ var dnsAddr string
 
 type dnsRecords struct {
 	sync.RWMutex
-	dnsMap  map[string]string
-	hostMap *HostMap
+	l               *logrus.Logger
+	dnsMap4         map[string]netip.Addr
+	dnsMap6         map[string]netip.Addr
+	hostMap         *HostMap
+	myVpnAddrsTable *bart.Table[struct{}]
 }
 
-func newDnsRecords(hostMap *HostMap) *dnsRecords {
+func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
 	return &dnsRecords{
-		dnsMap:  make(map[string]string),
-		hostMap: hostMap,
+		l:               l,
+		dnsMap4:         make(map[string]netip.Addr),
+		dnsMap6:         make(map[string]netip.Addr),
+		hostMap:         hostMap,
+		myVpnAddrsTable: cs.myVpnAddrsTable,
 	}
 }
 
-func (d *dnsRecords) Query(data string) string {
+func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
+	data = strings.ToLower(data)
 	d.RLock()
 	defer d.RUnlock()
-	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
-		return r
+	switch q {
+	case dns.TypeA:
+		if r, ok := d.dnsMap4[data]; ok {
+			return r
+		}
+	case dns.TypeAAAA:
+		if r, ok := d.dnsMap6[data]; ok {
+			return r
+		}
 	}
-	return ""
+
+	return netip.Addr{}
 }
 
 func (d *dnsRecords) QueryCert(data string) string {
@@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 	}
 
-	hostinfo := d.hostMap.QueryVpnIp(ip)
+	hostinfo := d.hostMap.QueryVpnAddr(ip)
 	if hostinfo == nil {
 		return ""
 	}
@@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
 	return string(b)
 }
 
-func (d *dnsRecords) Add(host, data string) {
+// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
+func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
+	host = strings.ToLower(host)
 	d.Lock()
 	defer d.Unlock()
-	d.dnsMap[strings.ToLower(host)] = data
+	haveV4 := false
+	haveV6 := false
+	for _, addr := range addresses {
+		if addr.Is4() && !haveV4 {
+			d.dnsMap4[host] = addr
+			haveV4 = true
+		} else if addr.Is6() && !haveV6 {
+			d.dnsMap6[host] = addr
+			haveV6 = true
+		}
+		if haveV4 && haveV6 {
+			break
+		}
+	}
+}
+
+func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
+	a, _, _ := net.SplitHostPort(addr)
+	b, err := netip.ParseAddr(a)
+	if err != nil {
+		return false
+	}
+
+	if b.IsLoopback() {
+		return true
+	}
+
+	_, found := d.myVpnAddrsTable.Lookup(b)
+	return found //if we found it in this table, it's good
 }
 
-func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
+func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 	for _, q := range m.Question {
 		switch q.Qtype {
-		case dns.TypeA:
-			l.Debugf("Query for A %s", q.Name)
-			ip := dnsR.Query(q.Name)
-			if ip != "" {
-				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
+		case dns.TypeA, dns.TypeAAAA:
+			qType := dns.TypeToString[q.Qtype]
+			d.l.Debugf("Query for %s %s", qType, q.Name)
+			ip := d.Query(q.Qtype, q.Name)
+			if ip.IsValid() {
+				rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
 				if err == nil {
 					m.Answer = append(m.Answer, rr)
 				}
 			}
 		case dns.TypeTXT:
-			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b, err := netip.ParseAddr(a)
-			if err != nil {
+			// We only answer these queries from nebula nodes or localhost
+			if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
 				return
 			}
-
-			// We don't answer these queries from non nebula nodes or localhost
-			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
-			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
-				return
-			}
-			l.Debugf("Query for TXT %s", q.Name)
-			ip := dnsR.QueryCert(q.Name)
+			d.l.Debugf("Query for TXT %s", q.Name)
+			ip := d.QueryCert(q.Name)
 			if ip != "" {
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				if err == nil {
@@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 	}
 }
 
-func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
+func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.Compress = false
 
 	switch r.Opcode {
 	case dns.OpcodeQuery:
-		parseQuery(l, m, w)
+		d.parseQuery(m, w)
 	}
 
 	w.WriteMsg(m)
 }
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
-	dnsR = newDnsRecords(hostMap)
+func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
+	dnsR = newDnsRecords(l, cs, hostMap)
 
 	// attach request handler func
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		handleDnsRequest(l, w, r)
-	})
+	dns.HandleFunc(".", dnsR.handleDnsRequest)
 
 	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)

+ 20 - 5
dns_server_test.go

@@ -1,23 +1,38 @@
 package nebula
 
 import (
+	"net/netip"
 	"testing"
 
 	"github.com/miekg/dns"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestParsequery(t *testing.T) {
-	//TODO: This test is basically pointless
+	l := logrus.New()
 	hostMap := &HostMap{}
-	ds := newDnsRecords(hostMap)
-	ds.Add("test.com.com", "1.2.3.4")
+	ds := newDnsRecords(l, &CertState{}, hostMap)
+	addrs := []netip.Addr{
+		netip.MustParseAddr("1.2.3.4"),
+		netip.MustParseAddr("1.2.3.5"),
+		netip.MustParseAddr("fd01::24"),
+		netip.MustParseAddr("fd01::25"),
+	}
+	ds.Add("test.com.com", addrs)
 
-	m := new(dns.Msg)
+	m := &dns.Msg{}
 	m.SetQuestion("test.com.com", dns.TypeA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
 
-	//parseQuery(m)
+	m = &dns.Msg{}
+	m.SetQuestion("test.com.com", dns.TypeAAAA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
 }
 
 func Test_getDnsServerAddr(t *testing.T) {

File diff suppressed because it is too large
+ 304 - 186
e2e/handshakes_test.go


+ 74 - 37
e2e/helpers_test.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/netip"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
@@ -17,6 +18,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/stretchr/testify/assert"
@@ -26,25 +28,35 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
+func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 
-	vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
-	if err != nil {
-		panic(err)
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
 	}
 
 	var udpAddr netip.AddrPort
-	if vpnIpNet.Addr().Is4() {
-		budpIp := vpnIpNet.Addr().As4()
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
 		budpIp[1] -= 128
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
 	} else {
-		budpIp := vpnIpNet.Addr().As16()
-		budpIp[13] -= 128
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
-	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{})
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
 	caB, err := caCrt.MarshalPEM()
 	if err != nil {
@@ -88,11 +100,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
 	}
 
 	if overrides != nil {
-		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		final := m{}
+		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
 		if err != nil {
 			panic(err)
 		}
-		mc = overrides
+		mc = final
 	}
 
 	cb, err := yaml.Marshal(mc)
@@ -109,7 +126,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
 		panic(err)
 	}
 
-	return control, vpnIpNet, udpAddr, c
+	return control, vpnNetworks, udpAddr, c
 }
 
 type doneCb func()
@@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
-	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
+	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 	// And once more from me to them
-	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
+	controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 
-func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
+	//TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things
+	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
-	hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
+	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
+	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
+	assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
 
 	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
 	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@@ -160,25 +178,36 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
 	// Check that our indexes match
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
-
-	//TODO: Would be nice to assert this memory
-	//checkIndexes := func(name string, hm *HostMap, hi *HostInfo) {
-	//	hBbyIndex := hmA.Indexes[hBinA.localIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by local index in %s", name)
-	//	assert.Equal(t, &hBbyIndex, &hBinA, "%s Indexes map did not point to the right host info", name)
-	//
-	//	//TODO: remote indexes are susceptible to collision
-	//	hBbyRemoteIndex := hmA.RemoteIndexes[hBinA.remoteIndexId]
-	//	assert.NotNil(t, hBbyIndex, "Could not host info by remote index in %s", name)
-	//	assert.Equal(t, &hBbyRemoteIndex, &hBinA, "%s RemoteIndexes did not point to the right host info", name)
-	//}
-	//
-	//// Check hostmap indexes too
-	//checkIndexes("hmA", hmA, hBinA)
-	//checkIndexes("hmB", hmB, hAinB)
 }
 
 func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	if toIp.Is6() {
+		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
+	} else {
+		assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
+	}
+}
+
+func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
+	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+	assert.NotNil(t, v6, "No ipv6 data found")
+
+	assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
+
+	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	assert.NotNil(t, udp, "No udp data found")
+
+	assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
+	assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
+
+	data := packet.ApplicationLayer()
+	assert.NotNil(t, data)
+	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
+}
+
+func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
@@ -197,6 +226,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 
+func getAddrs(ns []netip.Prefix) []netip.Addr {
+	var a []netip.Addr
+	for _, n := range ns {
+		a = append(a, n.Addr())
+	}
+	return a
+}
+
 func NewTestLogger() *logrus.Logger {
 	l := logrus.New()
 

+ 4 - 3
e2e/router/hostmap.go

@@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	var lines []string
 	var globalLines []*edge
 
-	clusterName := strings.Trim(c.GetCert().Name(), " ")
-	clusterVpnIp := c.GetCert().Networks()[0].Addr()
+	crt := c.GetCertState().GetDefaultCertificate()
+	clusterName := strings.Trim(crt.Name(), " ")
+	clusterVpnIp := crt.Networks()[0].Addr()
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 
 	hm := c.GetHostmap()
@@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	for _, idx := range indexes {
 		hi, ok := hm.Indexes[idx]
 		if ok {
-			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
+			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
 			remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			_ = hi

+ 44 - 22
e2e/router/router.go

@@ -10,8 +10,8 @@ import (
 	"os"
 	"path/filepath"
 	"reflect"
+	"regexp"
 	"sort"
-	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 			panic("Duplicate listen address: " + addr.String())
 		}
 
-		r.vpnControls[c.GetVpnIp()] = c
+		for _, vpnAddr := range c.GetVpnAddrs() {
+			r.vpnControls[vpnAddr] = c
+		}
+
 		r.controls[addr] = c
 	}
 
@@ -213,11 +216,11 @@ func (r *R) renderFlow() {
 			continue
 		}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr.String(), ":", "-", 1)
+		sanAddr := normalizeName(addr.String())
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
-			sanAddr, e.packet.from.GetVpnIp(), sanAddr,
+			sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
 		)
 	}
 
@@ -250,9 +253,9 @@ func (r *R) renderFlow() {
 
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.from.GetUDPAddr().String()),
 				line,
-				strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.to.GetUDPAddr().String()),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 		}
@@ -267,6 +270,11 @@ func (r *R) renderFlow() {
 	}
 }
 
+func normalizeName(s string) string {
+	rx := regexp.MustCompile("[\\[\\]\\:]")
+	return rx.ReplaceAllLiteralString(s, "_")
+}
+
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
+		return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
 	})
 
 	s := renderHostmaps(c...)
@@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 		// Nope, lets push the sender along
 		case p := <-udpTx:
 			r.Lock()
-			c := r.getControl(sender.GetUDPAddr(), p.To, p)
+			a := sender.GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 				r.Unlock()
-				panic("No control for udp tx")
+				panic("No control for udp tx " + a.String())
 			}
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			c.InjectUDPPacket(p)
@@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
-			c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
+			a := cm[x].GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 				r.Unlock()
-				panic("No control for udp tx")
+				panic(fmt.Sprintf("No control for udp tx %s", p.To))
 			}
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			c.InjectUDPPacket(p)
@@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
 }
 
 func (r *R) formatUdpPacket(p *packet) string {
-	packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
-	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
-	if v4 == nil {
-		panic("not an ipv4 packet")
+	var packet gopacket.Packet
+	var srcAddr netip.Addr
+
+	packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
+	if packet.ErrorLayer() == nil {
+		v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
+	} else {
+		packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
+		v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
 	}
 
 	from := "unknown"
-	srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
 	if c, ok := r.vpnControls[srcAddr]; ok {
 		from = c.GetUDPAddr().String()
 	}
 
-	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
-	if udp == nil {
+	udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	if udpLayer == nil {
 		panic("not a udp packet")
 	}
 
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
-		udp.SrcPort,
-		udp.DstPort,
+		normalizeName(from),
+		normalizeName(p.to.GetUDPAddr().String()),
+		udpLayer.SrcPort,
+		udpLayer.DstPort,
 		string(data.Payload()),
 	)
 }

+ 12 - 5
examples/config.yml

@@ -13,6 +13,12 @@ pki:
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   #disconnect_invalid: true
 
+  # default_version controls which certificate version is used in handshakes.
+  # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
+  # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
+  # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
+  # default_version: 1
+
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # The syntax is:
@@ -244,7 +250,6 @@ tun:
   # in nebula configuration files. Default false, not reloadable.
   #use_system_route_table: false
 
-# TODO
 # Configure logging level
 logging:
   # panic, fatal, error, warning, info, or debug. Default is info and is reloadable.
@@ -336,10 +341,12 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
-  #      Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
-  #      if `default_local_cidr_any` is false, otherwise its `any`.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
+  #     If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
+  #     Otherwise the default is any vpn network assigned to via the certificate.
+  #     `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
+  #     If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
 

+ 65 - 56
firewall.go

@@ -22,7 +22,7 @@ import (
 )
 
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
 }
 
 type conn struct {
@@ -51,9 +51,12 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 
-	// Used to ensure we don't emit local packets for ips we don't own
-	localIps          *bart.Table[struct{}]
-	assignedCIDR      netip.Prefix
+	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
+	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
+	routableNetworks *bart.Table[struct{}]
+
+	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
+	assignedNetworks  []netip.Prefix
 	hasUnsafeNetworks bool
 
 	rules        string
@@ -67,9 +70,9 @@ type Firewall struct {
 }
 
 type firewallMetrics struct {
-	droppedLocalIP  metrics.Counter
-	droppedRemoteIP metrics.Counter
-	droppedNoRule   metrics.Counter
+	droppedLocalAddr  metrics.Counter
+	droppedRemoteAddr metrics.Counter
+	droppedNoRule     metrics.Counter
 }
 
 type FirewallConntrack struct {
@@ -126,84 +129,87 @@ type firewallLocalCIDR struct {
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
+// The certificate provided should be the highest version loaded in memory.
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
 	//TODO: error on 0 duration
-	var min, max time.Duration
+	var tmin, tmax time.Duration
 
 	if tcpTimeout < UDPTimeout {
-		min = tcpTimeout
-		max = UDPTimeout
+		tmin = tcpTimeout
+		tmax = UDPTimeout
 	} else {
-		min = UDPTimeout
-		max = tcpTimeout
+		tmin = UDPTimeout
+		tmax = tcpTimeout
 	}
 
-	if defaultTimeout < min {
-		min = defaultTimeout
-	} else if defaultTimeout > max {
-		max = defaultTimeout
+	if defaultTimeout < tmin {
+		tmin = defaultTimeout
+	} else if defaultTimeout > tmax {
+		tmax = defaultTimeout
 	}
 
-	localIps := new(bart.Table[struct{}])
-	var assignedCIDR netip.Prefix
-	var assignedSet bool
+	routableNetworks := new(bart.Table[struct{}])
+	var assignedNetworks []netip.Prefix
 	for _, network := range c.Networks() {
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
-		localIps.Insert(nprefix, struct{}{})
-
-		if !assignedSet {
-			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = nprefix
-			assignedSet = true
-		}
+		routableNetworks.Insert(nprefix, struct{}{})
+		assignedNetworks = append(assignedNetworks, network)
 	}
 
 	hasUnsafeNetworks := false
 	for _, n := range c.UnsafeNetworks() {
-		localIps.Insert(n, struct{}{})
+		routableNetworks.Insert(n, struct{}{})
 		hasUnsafeNetworks = true
 	}
 
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 			Conns:      make(map[firewall.Packet]*conn),
-			TimerWheel: NewTimerWheel[firewall.Packet](min, max),
+			TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
 		},
 		InRules:           newFirewallTable(),
 		OutRules:          newFirewallTable(),
 		TCPTimeout:        tcpTimeout,
 		UDPTimeout:        UDPTimeout,
 		DefaultTimeout:    defaultTimeout,
-		localIps:          localIps,
-		assignedCIDR:      assignedCIDR,
+		routableNetworks:  routableNetworks,
+		assignedNetworks:  assignedNetworks,
 		hasUnsafeNetworks: hasUnsafeNetworks,
 		l:                 l,
 
 		incomingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
 		},
 		outgoingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
 		},
 	}
 }
 
-func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
+	certificate := cs.getCertificate(cert.Version2)
+	if certificate == nil {
+		certificate = cs.getCertificate(cert.Version1)
+	}
+
+	if certificate == nil {
+		panic("No certificate available to reconfigure the firewall")
+	}
+
 	fw := NewFirewall(
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
-		nc,
+		certificate,
 		//TODO: max_connections
 	)
 
-	//TODO: Flip to false after v1.9 release
-	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
+	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
 
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	switch inboundAction {
@@ -283,7 +289,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		fp = ft.TCP
 	case firewall.ProtoUDP:
 		fp = ft.UDP
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		fp = ft.ICMP
 	case firewall.ProtoAny:
 		fp = ft.AnyProto
@@ -424,26 +430,24 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 
 	// Make sure remote address matches nebula certificate
-	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-		_, ok := remoteCidr.Lookup(fp.RemoteIP)
+	if h.networks != nil {
+		_, ok := h.networks.Lookup(fp.RemoteAddr)
 		if !ok {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.vpnIp {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-	_, ok := f.localIps.Lookup(fp.LocalIP)
+	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
 	if !ok {
-		f.metrics(incoming).droppedLocalIP.Inc(1)
+		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 	}
 
@@ -629,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 		}
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 		}
@@ -859,9 +863,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
 	}
 
 	matched := false
-	prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
+	prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
 	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
-		if prefix.Contains(p.RemoteIP) && val.match(p, c) {
+		if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
 			matched = true
 			return false
 		}
@@ -877,9 +881,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 			return nil
 		}
 
-		localIp = f.assignedCIDR
+		for _, network := range f.assignedNetworks {
+			flc.LocalCIDR.Insert(network, struct{}{})
+		}
+		return nil
+
 	} else if localIp.Bits() == 0 {
 		flc.Any = true
+		return nil
 	}
 
 	flc.LocalCIDR.Insert(localIp, struct{}{})
@@ -895,7 +904,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
 		return true
 	}
 
-	_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
 	return ok
 }
 

+ 11 - 10
firewall/packet.go

@@ -9,18 +9,19 @@ import (
 type m map[string]interface{}
 
 const (
-	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	ProtoTCP  = 6
-	ProtoUDP  = 17
-	ProtoICMP = 1
+	ProtoAny    = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP    = 6
+	ProtoUDP    = 17
+	ProtoICMP   = 1
+	ProtoICMPv6 = 58
 
 	PortAny      = 0  // Special value for matching `port: any`
 	PortFragment = -1 // Special value for matching `port: fragment`
 )
 
 type Packet struct {
-	LocalIP    netip.Addr
-	RemoteIP   netip.Addr
+	LocalAddr  netip.Addr
+	RemoteAddr netip.Addr
 	LocalPort  uint16
 	RemotePort uint16
 	Protocol   uint8
@@ -29,8 +30,8 @@ type Packet struct {
 
 func (fp *Packet) Copy() *Packet {
 	return &Packet{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
+		LocalAddr:  fp.LocalAddr,
+		RemoteAddr: fp.RemoteAddr,
 		LocalPort:  fp.LocalPort,
 		RemotePort: fp.RemotePort,
 		Protocol:   fp.Protocol,
@@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 	}
 	return json.Marshal(m{
-		"LocalIP":    fp.LocalIP.String(),
-		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalAddr":  fp.LocalAddr.String(),
+		"RemoteAddr": fp.RemoteAddr.String(),
 		"LocalPort":  fp.LocalPort,
 		"RemotePort": fp.RemotePort,
 		"Protocol":   proto,

+ 41 - 38
firewall_test.go

@@ -13,6 +13,7 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewFirewall(t *testing.T) {
@@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
@@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) {
 				InvertedGroups: map[string]struct{}{"default-group": {}},
 			},
 		},
-		vpnIp: netip.MustParseAddr("1.2.3.4"),
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(c.networks, c.unsafeNetworks)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 	// test remote mismatch
-	oldRemote := p.RemoteIP
-	p.RemoteIP = netip.MustParseAddr("1.2.3.10")
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
-	p.RemoteIP = oldRemote
+	p.RemoteAddr = oldRemote
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		}
 		ip := netip.MustParsePrefix("9.254.254.254/32")
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 	})
 
@@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
 
@@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			InvertedGroups: map[string]struct{}{"good-group": {}},
 		}
 		for n := 0; n < b.N; n++ {
-			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
 
@@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
@@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.CreateRemoteCIDR(c.Certificate)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	c1 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -341,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) {
 		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
 	}
 	h1 := HostInfo{
+		vpnAddrs: []netip.Addr{network.Addr()},
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
 	}
-	h1.CreateRemoteCIDR(c1.Certificate)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -364,8 +366,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  1,
 		RemotePort: 1,
 		Protocol:   firewall.ProtoUDP,
@@ -391,9 +393,9 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h1.CreateRemoteCIDR(c1.Certificate)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	c2 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -406,9 +408,9 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h2.CreateRemoteCIDR(c2.Certificate)
+	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
 
 	c3 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -421,9 +423,9 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h3.CreateRemoteCIDR(c3.Certificate)
+	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -446,8 +448,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
@@ -468,9 +470,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.CreateRemoteCIDR(c.Certificate)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -574,8 +576,6 @@ func BenchmarkLookup(b *testing.B) {
 			ml(m, a)
 		}
 	})
-
-	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
 }
 
 func Test_parsePort(t *testing.T) {
@@ -622,55 +622,58 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	l := test.NewLogger()
 	// Test a bad rule definition
 	c := &dummyCert{}
+	cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
+	require.NoError(t, err)
+
 	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
-	_, err := NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 	// Test both port and code
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 	// Test code/port error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 	// Test proto error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 	// Test cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 

+ 0 - 1
go.mod

@@ -21,7 +21,6 @@ require (
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
-	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stretchr/testify v1.9.0
 	github.com/vishvananda/netlink v1.3.0

+ 0 - 2
go.sum

@@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

+ 196 - 76
handshake_ix.go

@@ -2,10 +2,12 @@ package nebula
 
 import (
 	"net/netip"
+	"slices"
 	"time"
 
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 )
 
@@ -16,30 +18,60 @@ import (
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 	}
 
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
-	hh.hostinfo.ConnectionState = ci
+	// If we're connecting to a v6 address we must use a v2 cert
+	cs := f.pki.getCertState()
+	v := cs.defaultVersion
+	for _, a := range hh.hostinfo.vpnAddrs {
+		if a.Is6() {
+			v = cert.Version2
+			break
+		}
+	}
 
-	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hh.hostinfo.localIndexId,
-		Time:           uint64(time.Now().UnixNano()),
-		Cert:           certState.RawCertificateNoKey,
+	crt := cs.getCertificate(v)
+	if crt == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate is available")
+		return false
 	}
 
-	hsBytes := []byte{}
+	crtHs := cs.getHandshakeBytes(v)
+	if crtHs == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Failed to create connection state")
+		return false
+	}
+	hh.hostinfo.ConnectionState = ci
 
 	hs := &NebulaHandshake{
-		Details: hsProto,
+		Details: &NebulaHandshakeDetails{
+			InitiatorIndex: hh.hostinfo.localIndexId,
+			Time:           uint64(time.Now().UnixNano()),
+			Cert:           crtHs,
+			CertVersion:    uint32(v),
+		},
 	}
-	hsBytes, err = hs.Marshal()
 
+	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 	}
@@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 	}
@@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 }
 
 func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
+	cs := f.pki.getCertState()
+	crt := cs.GetDefaultCertificate()
+	if crt == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", cs.defaultVersion).
+			Error("Unable to handshake with host because no certificate is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to create connection state")
+		return
+	}
+
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to call noise.ReadMessage")
 		return
 	}
 
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
-	/*
-		l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
-	*/
 	if err != nil || hs.Details == nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed unmarshal handshake message")
 		return
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 	}
 
+	if remoteCert.Certificate.Version() != ci.myCert.Version() {
+		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
+		rc := cs.getCertificate(remoteCert.Certificate.Version())
+		if rc == nil {
+			f.l.WithError(err).WithField("udpAddr", addr).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
+				Info("Unable to handshake with host due to missing certificate version")
+			return
+		}
+
+		// Record the certificate we are actually using
+		ci.myCert = rc
+	}
+
 	if len(remoteCert.Certificate.Networks()) == 0 {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@@ -111,30 +171,54 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 	}
 
-	vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
 	certName := remoteCert.Certificate.Name()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
 
-	if vpnIp == f.myVpnNet.Addr() {
-		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	for _, network := range remoteCert.Certificate.Networks() {
+		vpnAddr := network.Addr()
+		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
+		if found {
+			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("fingerprint", fingerprint).
+				WithField("issuer", issuer).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			return
+		}
+
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
 		return
 	}
 
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		// addr can be invalid when the tunnel is being relayed.
+		// We only want to apply the remote allow list for direct tunnels here
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
 
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -146,17 +230,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		vpnIp:             vpnIp,
+		vpnAddrs:          vpnAddrs,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}
 
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
@@ -165,13 +249,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		Info("Handshake message received")
 
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = certState.RawCertificateNoKey
+	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
+	if hs.Details.Cert == nil {
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("certVersion", ci.myCert.Version()).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return
+	}
+
+	hs.Details.CertVersion = uint32(ci.myCert.Version())
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
 	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -182,14 +279,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -213,9 +310,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
-	hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -225,7 +322,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
-				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 
 			msg = existing.HandshakePacket[2]
@@ -233,11 +330,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			if addr.IsValid() {
 				err := f.outside.WriteTo(msg, addr)
 				if err != nil {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 				} else {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 				}
@@ -247,16 +344,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 				}
-				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
+				f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 				return
 			}
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@@ -267,23 +364,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				Info("Handshake too old")
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -299,7 +396,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if addr.IsValid() {
 		err = f.outside.WriteTo(msg, addr)
 		if err != nil {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -307,7 +404,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake")
 		} else {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
@@ -320,9 +417,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 		}
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+		// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
+		// it's correctly marked as working.
+		via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
@@ -349,8 +449,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	hostinfo := hh.hostinfo
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
 	}
@@ -358,7 +459,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -367,7 +468,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 
@@ -379,16 +480,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
 		return true
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
-		e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 
 		if f.l.Level > logrus.DebugLevel {
@@ -409,11 +510,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			e = e.WithField("cert", remoteCert)
 		}
 
-		e.Info("Invalid vpn ip from host")
+		e.Info("Empty networks from host")
 		return true
 	}
 
-	vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
+	vpnNetworks := remoteCert.Certificate.Networks()
 	certName := remoteCert.Certificate.Name()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
@@ -430,12 +531,34 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	if addr.IsValid() {
 		hostinfo.SetRemote(addr)
 	} else {
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+	}
+
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	for _, network := range vpnNetworks {
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		vpnAddr := network.Addr()
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+		return true
 	}
 
 	// Ensure the right host responded
-	if vpnIp != hostinfo.vpnIp {
-		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
+	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
@@ -444,14 +567,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
-			//TODO: this doesnt know if its being added or is being used for caching a packet
+		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes.BlockRemote(addr)
 
 			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
-				WithField("vpnIp", newHH.hostinfo.vpnIp).
+				WithField("vpnNetworks", vpnNetworks).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				Info("Blocked addresses for handshakes")
 
@@ -459,11 +581,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			newHH.packetStore = hh.packetStore
 			hh.packetStore = []*cachedPacket{}
 
-			// Get the correct remote list for the host we did handshake with
-			hostinfo.SetRemote(addr)
-			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
-			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-			hostinfo.vpnIp = vpnIp
+			// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnAddrs = vpnAddrs
 			f.sendCloseTunnel(hostinfo)
 		})
 
@@ -474,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
@@ -485,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		Info("Handshake message received")
 
 	// Build up the radix for the firewall if we have subnets in the cert
-	hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
+	hostinfo.vpnAddrs = vpnAddrs
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
-	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
+	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 

+ 180 - 123
handshake_manager.go

@@ -13,6 +13,7 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 )
@@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
 	}
 }
 
-func (c *HandshakeManager) Run(ctx context.Context) {
-	clockSource := time.NewTicker(c.config.tryInterval)
+func (hm *HandshakeManager) Run(ctx context.Context) {
+	clockSource := time.NewTicker(hm.config.tryInterval)
 	defer clockSource.Stop()
 
 	for {
 		select {
 		case <-ctx.Done():
 			return
-		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, true)
+		case vpnIP := <-hm.trigger:
+			hm.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now)
+			hm.NextOutboundHandshakeTimerTick(now)
 		}
 	}
 }
@@ -137,7 +138,7 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
 	if addr.IsValid() {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
@@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
 	}
 }
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
-	c.OutboundHandshakeTimer.Advance(now)
+func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
+	hm.OutboundHandshakeTimer.Advance(now)
 	for {
-		vpnIp, has := c.OutboundHandshakeTimer.Purge()
+		vpnIp, has := hm.OutboundHandshakeTimer.Purge()
 		if !has {
 			break
 		}
-		c.handleOutbound(vpnIp, false)
+		hm.handleOutbound(vpnIp, false)
 	}
 }
 
@@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
 	}
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@@ -223,7 +224,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 
 	hh.lastRemotes = remotes
 
-	// TODO: this will generate a load of queries for hosts with only 1 ip
+	// This will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
 	if len(remotes) <= 1 && hh.counter == 5 {
@@ -267,59 +268,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
+			// Don't relay to myself
+			if relay == vpnIp {
 				continue
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+
+			// Don't relay through the host I'm trying to connect to
+			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
+			if found {
+				continue
+			}
+
+			relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
 			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hm.f.Handshake(relay)
 				continue
 			}
-			// Check the relay HostInfo to see if we already established a relay through it
-			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
-				switch existingRelay.State {
-				case Established:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
-				case Requested:
-					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
-
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
-					// Re-send the CreateRelay request, in case the previous one was lost.
-					m := NebulaControl{
-						Type:                NebulaControl_CreateRelayRequest,
-						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
-					}
-					msg, err := m.Marshal()
-					if err != nil {
-						hostinfo.logger(hm.l).
-							WithError(err).
-							Error("Failed to marshal Control message to create relay")
-					} else {
-						// This must send over the hostinfo, not over hm.Hosts[ip]
-						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
-							"relayTo":             vpnIp,
-							"initiatorRelayIndex": existingRelay.LocalIndex,
-							"relay":               relay}).
-							Info("send CreateRelayRequest")
-					}
-				default:
-					hostinfo.logger(hm.l).
-						WithField("vpnIp", vpnIp).
-						WithField("state", existingRelay.State).
-						WithField("relay", relayHostInfo.vpnIp).
-						Errorf("Relay unexpected state")
-				}
-			} else {
+			// Check the relay HostInfo to see if we already established a relay through
+			existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp)
+			if !ok {
 				// No relays exist or requested yet.
 				if relayHostInfo.remote.IsValid() {
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
@@ -327,16 +295,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
+
+					switch relayHostInfo.GetCert().Certificate.Version() {
+					case cert.Version1:
+						if !hm.f.myVpnAddrs[0].Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+							continue
+						}
+
+						if !vpnIp.Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+							continue
+						}
+
+						b := hm.f.myVpnAddrs[0].As4()
+						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+						b = vpnIp.As4()
+						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					case cert.Version2:
+						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					default:
+						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						continue
+					}
+
 					msg, err := m.Marshal()
 					if err != nil {
 						hostinfo.logger(hm.l).
@@ -345,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
+							"relayFrom":           hm.f.myVpnAddrs[0],
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
 							Info("send CreateRelayRequest")
 					}
 				}
+				continue
+			}
+
+			switch existingRelay.State {
+			case Established:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+				hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+			case Disestablished:
+				// Mark this relay as 'requested'
+				relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
+				fallthrough
+			case Requested:
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+				// Re-send the CreateRelay request, in case the previous one was lost.
+				m := NebulaControl{
+					Type:                NebulaControl_CreateRelayRequest,
+					InitiatorRelayIndex: existingRelay.LocalIndex,
+				}
+
+				switch relayHostInfo.GetCert().Certificate.Version() {
+				case cert.Version1:
+					if !hm.f.myVpnAddrs[0].Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+						continue
+					}
+
+					if !vpnIp.Is4() {
+						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+						continue
+					}
+
+					b := hm.f.myVpnAddrs[0].As4()
+					m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+					b = vpnIp.As4()
+					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+				case cert.Version2:
+					m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+				default:
+					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+					continue
+				}
+				msg, err := m.Marshal()
+				if err != nil {
+					hostinfo.logger(hm.l).
+						WithError(err).
+						Error("Failed to marshal Control message to create relay")
+				} else {
+					// This must send over the hostinfo, not over hm.Hosts[ip]
+					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+					hm.l.WithFields(logrus.Fields{
+						"relayFrom":           hm.f.myVpnAddrs[0],
+						"relayTo":             vpnIp,
+						"initiatorRelayIndex": existingRelay.LocalIndex,
+						"relay":               relay}).
+						Info("send CreateRelayRequest")
+				}
+			case PeerRequested:
+				// PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case.
+				fallthrough
+			default:
+				hostinfo.logger(hm.l).
+					WithField("vpnIp", vpnIp).
+					WithField("state", existingRelay.State).
+					WithField("relay", relay).
+					Errorf("Relay unexpected state")
+
 			}
 		}
 	}
@@ -381,10 +435,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 }
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 
-	if hh, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnAddr]; ok {
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
 			cacheCb(hh)
@@ -394,12 +448,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 	}
 
 	hostinfo := &HostInfo{
-		vpnIp:           vpnIp,
+		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 	}
 
@@ -407,9 +461,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 		hostinfo:  hostinfo,
 		startTime: time.Now(),
 	}
-	hm.vpnIps[vpnIp] = hh
+	hm.vpnIps[vpnAddr] = hh
 	hm.metricInitiated.Inc(1)
-	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
+	hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
 
 	if cacheCb != nil {
 		cacheCb(hh)
@@ -417,21 +471,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// We can trigger the handshake right now
-	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
 	if !doTrigger {
 		// Add any calculated remotes, and trigger early handshake if one found
-		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
 	}
 
 	if doTrigger {
 		select {
-		case hm.trigger <- vpnIp:
+		case hm.trigger <- vpnAddr:
 		default:
 		}
 	}
 
 	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp)
+	hm.lightHouse.QueryServer(vpnAddr)
 	return hostinfo
 }
 
@@ -452,14 +506,14 @@ var (
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
-	c.Lock()
-	defer c.Unlock()
+func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
 	if found && existingHostInfo != nil {
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
@@ -476,31 +530,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			return existingHostInfo, ErrExistingHostInfo
 		}
 
-		existingHostInfo.logger(c.l).Info("Taking new handshake")
+		existingHostInfo.logger(hm.l).Info("Taking new handshake")
 	}
 
-	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
 	if found {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
 	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
 		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 	}
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+		hostinfo.logger(hm.l).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 	}
 
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 }
 
@@ -518,7 +572,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(hm.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 	}
 
@@ -555,31 +609,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 }
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	c.Lock()
-	defer c.Unlock()
-	c.unlockedDeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
 }
 
-func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	delete(c.vpnIps, hostinfo.vpnIp)
-	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
+func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	for _, addr := range hostinfo.vpnAddrs {
+		delete(hm.vpnIps, addr)
 	}
 
-	delete(c.indexes, hostinfo.localIndexId)
-	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HandshakeHostInfo{}
+	if len(hm.vpnIps) == 0 {
+		hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+	delete(hm.indexes, hostinfo.localIndexId)
+	if len(hm.indexes) == 0 {
+		hm.indexes = map[uint32]*HandshakeHostInfo{}
+	}
+
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Pending hostmap hostInfo deleted")
 	}
 }
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
+func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 		return hh.hostinfo
@@ -608,37 +665,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 }
 
-func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
-	return c.mainHostMap.GetPreferredRanges()
+func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
+	return hm.mainHostMap.GetPreferredRanges()
 }
 
-func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
-	for _, v := range c.vpnIps {
+	for _, v := range hm.vpnIps {
 		f(v.hostinfo)
 	}
 }
 
-func (c *HandshakeManager) ForEachIndex(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
-	for _, v := range c.indexes {
+	for _, v := range hm.indexes {
 		f(v.hostinfo)
 	}
 }
 
-func (c *HandshakeManager) EmitStats() {
-	c.RLock()
-	hostLen := len(c.vpnIps)
-	indexLen := len(c.indexes)
-	c.RUnlock()
+func (hm *HandshakeManager) EmitStats() {
+	hm.RLock()
+	hostLen := len(hm.vpnIps)
+	indexLen := len(hm.indexes)
+	hm.RUnlock()
 
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
-	c.mainHostMap.EmitStats()
+	hm.mainHostMap.EmitStats()
 }
 
 // Utility functions below

+ 19 - 11
handshake_manager_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
@@ -13,21 +14,20 @@ import (
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	ip := netip.MustParseAddr("172.1.1.2")
 
 	preferredRanges := []netip.Prefix{localrange}
-	mainHM := newHostMap(l, vpncidr)
+	mainHM := newHostMap(l)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
 	lh := newTestLighthouse()
 
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 
-	i.remotes = NewRemoteList(nil)
+	i.remotes = NewRemoteList([]netip.Addr{}, nil)
 
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
@@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 type mockEncWriter struct {
 }
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
 	return
 }
 
-func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
 	return
 }
 
-func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
 	return
 }
 
-func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
+func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
+
+func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
+	return nil
+}
+
+func (mw *mockEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: cert.Version2}
+}

+ 169 - 94
hostmap.go

@@ -35,6 +35,7 @@ const (
 	Requested = iota
 	PeerRequested
 	Established
+	Disestablished
 )
 
 const (
@@ -48,7 +49,7 @@ type Relay struct {
 	State       int
 	LocalIndex  uint32
 	RemoteIndex uint32
-	PeerIp      netip.Addr
+	PeerAddr    netip.Addr
 }
 
 type HostMap struct {
@@ -58,7 +59,6 @@ type HostMap struct {
 	RemoteIndexes   map[uint32]*HostInfo
 	Hosts           map[netip.Addr]*HostInfo
 	preferredRanges atomic.Pointer[[]netip.Prefix]
-	vpnCIDR         netip.Prefix
 	l               *logrus.Logger
 }
 
@@ -68,9 +68,12 @@ type HostMap struct {
 type RelayState struct {
 	sync.RWMutex
 
-	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
+	relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	// For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data,
+	// modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with
+	// the RelayState Lock held)
+	relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx  map[uint32]*Relay     // Maps a local index to some Relay info
 }
 
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
@@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	delete(rs.relays, ip)
 }
 
+func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByAddr[vpnIp]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
+func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) {
+	rs.Lock()
+	defer rs.Unlock()
+	if r, ok := rs.relayForByIdx[idx]; ok {
+		newRelay := *r
+		newRelay.State = state
+		rs.relayForByAddr[newRelay.PeerAddr] = &newRelay
+		rs.relayForByIdx[newRelay.LocalIndex] = &newRelay
+	}
+}
+
 func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	rs.RLock()
 	defer rs.RUnlock()
@@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 }
 
-func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[ip]
+	r, ok := rs.relayForByAddr[addr]
 	return r, ok
 }
 
@@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
 func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	defer rs.RUnlock()
-	currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
-	for relayIp := range rs.relayForByIp {
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
+	for relayIp := range rs.relayForByAddr {
 		currentRelays = append(currentRelays, relayIp)
 	}
 	return currentRelays
@@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	defer rs.Unlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	if !ok {
 		return false
 	}
@@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return true
 }
 
@@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return &newRelay, true
 }
 
 func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	return r, ok
 }
 
@@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	defer rs.Unlock()
-	rs.relayForByIp[ip] = r
+	rs.relayForByAddr[ip] = r
 	rs.relayForByIdx[idx] = r
 }
 
@@ -190,10 +215,16 @@ type HostInfo struct {
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	localIndexId    uint32
-	vpnIp           netip.Addr
-	recvError       atomic.Uint32
-	remoteCidr      *bart.Table[struct{}]
-	relayState      RelayState
+
+	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
+	// The host may have other vpn addresses that are outside our
+	// vpn networks but were removed because they are not usable
+	vpnAddrs  []netip.Addr
+	recvError atomic.Uint32
+
+	// networks are both all vpn and unsafe networks assigned to this host
+	networks   *bart.Table[struct{}]
+	relayState RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
 	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
@@ -241,28 +272,26 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
-	hm := newHostMap(l, vpnCIDR)
+func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
+	hm := newHostMap(l)
 
 	hm.reload(c, true)
 	c.RegisterReloadCallback(func(c *config.C) {
 		hm.reload(c, false)
 	})
 
-	l.WithField("network", hm.vpnCIDR.String()).
-		WithField("preferredRanges", hm.GetPreferredRanges()).
+	l.WithField("preferredRanges", hm.GetPreferredRanges()).
 		Info("Main HostMap created")
 
 	return hm
 }
 
-func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
+func newHostMap(l *logrus.Logger) *HostMap {
 	return &HostMap{
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		Hosts:         map[netip.Addr]*HostInfo{},
-		vpnCIDR:       vpnCIDR,
 		l:             l,
 	}
 }
@@ -305,17 +334,6 @@ func (hm *HostMap) EmitStats() {
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 
-func (hm *HostMap) RemoveRelay(localIdx uint32) {
-	hm.Lock()
-	_, ok := hm.Relays[localIdx]
-	if !ok {
-		hm.Unlock()
-		return
-	}
-	delete(hm.Relays, localIdx)
-	hm.Unlock()
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
@@ -335,48 +353,73 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 }
 
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
-	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	// Get the current primary, if it exists
+	oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
+
+	// Every address in the hostinfo gets elevated to primary
+	for _, vpnAddr := range hostinfo.vpnAddrs {
+		//NOTE: It is possible that we leave a dangling hostinfo here but connection manager works on
+		// indexes so it should be fine.
+		hm.Hosts[vpnAddr] = hostinfo
+	}
+
+	// If we are already primary then we won't bother re-linking
 	if oldHostinfo == hostinfo {
 		return
 	}
 
+	// Unlink this hostinfo
 	if hostinfo.prev != nil {
 		hostinfo.prev.next = hostinfo.next
 	}
-
 	if hostinfo.next != nil {
 		hostinfo.next.prev = hostinfo.prev
 	}
 
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
+	// If there wasn't a previous primary then clear out any links
 	if oldHostinfo == nil {
+		hostinfo.next = nil
+		hostinfo.prev = nil
 		return
 	}
 
+	// Relink the hostinfo as primary
 	hostinfo.next = oldHostinfo
 	oldHostinfo.prev = hostinfo
 	hostinfo.prev = nil
 }
 
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	for _, addr := range hostinfo.vpnAddrs {
+		h := hm.Hosts[addr]
+		for h != nil {
+			if h == hostinfo {
+				hm.unlockedInnerDeleteHostInfo(h, addr)
+			}
+			h = h.next
+		}
+	}
+}
+
+func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
+	primary, ok := hm.Hosts[addr]
+	isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil
 	if ok && primary == hostinfo {
-		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
-		delete(hm.Hosts, hostinfo.vpnIp)
+		// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, addr)
 		if len(hm.Hosts) == 0 {
 			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 
 		if hostinfo.next != nil {
-			// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
-			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
+			hm.Hosts[addr] = hostinfo.next
 			// It is primary, there is no previous hostinfo now
 			hostinfo.next.prev = nil
 		}
 
 	} else {
-		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		// Relink if we were in the middle of multiple hostinfos for this vpn addr
 		if hostinfo.prev != nil {
 			hostinfo.prev.next = hostinfo.next
 		}
@@ -406,10 +449,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
 
+	if isLastHostinfo {
+		// I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next
+		// hops as 'Requested' so that new relay tunnels are created in the future.
+		hm.unlockedDisestablishVpnAddrRelayFor(hostinfo)
+	}
+	// Clean up any local relay indexes for which I am acting as a relay hop
 	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
 		delete(hm.Relays, localRelayIdx)
 	}
@@ -448,11 +497,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 }
 
-func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
-	return hm.queryVpnIp(vpnIp, nil)
+func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
+	return hm.queryVpnAddr(vpnIp, nil)
 }
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -460,17 +509,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
 	if !ok {
 		return nil, nil, errors.New("unable to find host")
 	}
+
 	for h != nil {
-		r, ok := h.relayState.QueryRelayForByIp(targetIp)
-		if ok && r.State == Established {
-			return h, r, nil
+		for _, targetIp := range targetIps {
+			r, ok := h.relayState.QueryRelayForByIp(targetIp)
+			if ok && r.State == Established {
+				return h, r, nil
+			}
 		}
 		h = h.next
 	}
+
 	return nil, nil, errors.New("unable to find host with relay")
 }
 
-func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
+	for _, relayHostIp := range hi.relayState.CopyRelayIps() {
+		if h, ok := hm.Hosts[relayHostIp]; ok {
+			for h != nil {
+				h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+				h = h.next
+			}
+		}
+	}
+	for _, rs := range hi.relayState.CopyAllRelayFor() {
+		if rs.Type == ForwardingType {
+			if h, ok := hm.Hosts[rs.PeerAddr]; ok {
+				for h != nil {
+					h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished)
+					h = h.next
+				}
+			}
+		}
+	}
+}
+
+func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -491,25 +565,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
-		dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String())
+		dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
 	}
-
-	existing := hm.Hosts[hostinfo.vpnIp]
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
-	if existing != nil {
-		hostinfo.next = existing
-		existing.prev = hostinfo
+	for _, addr := range hostinfo.vpnAddrs {
+		hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
 	}
 
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
+		hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
 			Debug("Hostmap vpnIp added")
 	}
+}
+
+func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
+	existing := hm.Hosts[vpnAddr]
+	hm.Hosts[vpnAddr] = hostinfo
+
+	if existing != nil && existing != hostinfo {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
 
 	i := 1
 	check := hostinfo
@@ -527,7 +606,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	return *hm.preferredRanges.Load()
 }
 
-func (hm *HostMap) ForEachVpnIp(f controlEach) {
+func (hm *HostMap) ForEachVpnAddr(f controlEach) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -581,7 +660,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
 		}
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp)
+		ifce.lightHouse.QueryServer(i.vpnAddrs[0])
 	}
 }
 
@@ -596,7 +675,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	if i.remote != remote {
 		i.remote = remote
-		i.remotes.LearnRemote(i.vpnIp, remote)
+		i.remotes.LearnRemote(i.vpnAddrs[0], remote)
 	}
 }
 
@@ -647,21 +726,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 }
 
-func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) {
-	if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
+func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+	if len(networks) == 1 && len(unsafeNetworks) == 0 {
 		// Simple case, no CIDRTree needed
 		return
 	}
 
-	remoteCidr := new(bart.Table[struct{}])
-	for _, network := range c.Networks() {
-		remoteCidr.Insert(network, struct{}{})
+	i.networks = new(bart.Table[struct{}])
+	for _, network := range networks {
+		i.networks.Insert(network, struct{}{})
 	}
 
-	for _, network := range c.UnsafeNetworks() {
-		remoteCidr.Insert(network, struct{}{})
+	for _, network := range unsafeNetworks {
+		i.networks.Insert(network, struct{}{})
 	}
-	i.remoteCidr = remoteCidr
 }
 
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@@ -669,7 +747,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 	}
 
-	li := l.WithField("vpnIp", i.vpnIp).
+	li := l.WithField("vpnAddrs", i.vpnAddrs).
 		WithField("localIndex", i.localIndexId).
 		WithField("remoteIndex", i.remoteIndexId)
 
@@ -684,9 +762,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 // Utility functions
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
+func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
-	var ips []netip.Addr
+	var finalAddrs []netip.Addr
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
@@ -698,39 +776,36 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 			continue
 		}
 		addrs, _ := i.Addrs()
-		for _, addr := range addrs {
-			var ip net.IP
-			switch v := addr.(type) {
+		for _, rawAddr := range addrs {
+			var addr netip.Addr
+			switch v := rawAddr.(type) {
 			case *net.IPNet:
 				//continue
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			case *net.IPAddr:
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			}
 
-			nip, ok := netip.AddrFromSlice(ip)
-			if !ok {
+			if !addr.IsValid() {
 				if l.Level >= logrus.DebugLevel {
-					l.WithField("localIp", ip).Debug("ip was invalid for netip")
+					l.WithField("localAddr", rawAddr).Debug("addr was invalid")
 				}
 				continue
 			}
-			nip = nip.Unmap()
+			addr = addr.Unmap()
 
-			//TODO: Filtering out link local for now, this is probably the most correct thing
-			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
-				allow := allowList.Allow(nip)
+			if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
+				isAllowed := allowList.Allow(addr)
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
 				}
-				if !allow {
+				if !isAllowed {
 					continue
 				}
 
-				ips = append(ips, nip)
+				finalAddrs = append(finalAddrs, addr)
 			}
 		}
 	}
-	return ips
+	return finalAddrs
 }

+ 23 - 33
hostmap_test.go

@@ -11,17 +11,14 @@ import (
 
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
 
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
-	h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
-	h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
+	h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
+	h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
@@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 }
 
@@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
 
-	hm := NewHostMapFromConfig(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-		c,
-	)
+	hm := NewHostMapFromConfig(l, c)
 
 	toS := func(ipn []netip.Prefix) []string {
 		var s []string

+ 2 - 2
hostmap_tester.go

@@ -9,8 +9,8 @@ import (
 	"net/netip"
 )
 
-func (i *HostInfo) GetVpnIp() netip.Addr {
-	return i.vpnIp
+func (i *HostInfo) GetVpnAddrs() []netip.Addr {
+	return i.vpnAddrs
 }
 
 func (i *HostInfo) GetLocalIndex() uint32 {

+ 31 - 27
inside.go

@@ -20,14 +20,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
-		return
+	if f.dropLocalBroadcast {
+		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
+		if found {
+			return
+		}
 	}
 
-	if fwPacket.RemoteIP == f.myVpnNet.Addr() {
+	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
+	if found {
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
-		// routes packets from the Nebula IP to the Nebula IP through the Nebula
+		// routes packets from the Nebula addr to the Nebula addr through the Nebula
 		// TUN device.
 		if immediatelyForwardToSelf {
 			_, err := f.readers[q].Write(packet)
@@ -36,25 +40,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 			}
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
-		// routes packets from the nebula IP to the nebula IP through the loopback device.
+		// routes packets from the nebula addr to the nebula addr through the loopback device.
 		return
 	}
 
 	// Ignore multicast packets
-	if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
+	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", fwPacket.RemoteIP).
+			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
-				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
+				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
 		return
 	}
@@ -117,21 +121,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 
-func (f *Interface) Handshake(vpnIp netip.Addr) {
-	f.getOrHandshake(vpnIp, nil)
+func (f *Interface) Handshake(vpnAddr netip.Addr) {
+	f.getOrHandshake(vpnAddr, nil)
 }
 
-// getOrHandshake returns nil if the vpnIp is not routable.
+// getOrHandshake returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !f.myVpnNet.Contains(vpnIp) {
-		vpnIp = f.inside.RouteFor(vpnIp)
-		if !vpnIp.IsValid() {
+func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
+	if !found {
+		vpnAddr = f.inside.RouteFor(vpnAddr)
+		if !vpnAddr.IsValid() {
 			return nil, false
 		}
 	}
 
-	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
+	return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 }
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -156,16 +161,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 
-// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
+	hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", vpnIp).
-				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
+			f.l.WithField("vpnAddr", vpnAddr).
+				Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
 		}
 		return
 	}
@@ -258,7 +263,6 @@ func (f *Interface) SendVia(via *HostInfo,
 
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
-		//TODO: log warning
 		return
 	}
 	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
@@ -285,14 +289,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	f.connectionManager.Out(hostinfo.localIndexId)
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
-	// all our IPs and enable a faster roaming.
+	// all our addrs and enable a faster roaming.
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp)
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 	}
 
@@ -324,7 +328,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
 			if err != nil {
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

+ 83 - 60
interface.go

@@ -2,17 +2,16 @@ package nebula
 
 import (
 	"context"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
-	"net"
 	"net/netip"
 	"os"
 	"runtime"
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -29,7 +28,6 @@ type InterfaceConfig struct {
 	Outside                 udp.Conn
 	Inside                  overlay.Device
 	pki                     *PKI
-	Cipher                  string
 	Firewall                *Firewall
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
@@ -53,25 +51,27 @@ type InterfaceConfig struct {
 }
 
 type Interface struct {
-	hostMap            *HostMap
-	outside            udp.Conn
-	inside             overlay.Device
-	pki                *PKI
-	cipher             string
-	firewall           *Firewall
-	connectionManager  *connectionManager
-	handshakeManager   *HandshakeManager
-	serveDns           bool
-	createTime         time.Time
-	lightHouse         *LightHouse
-	myBroadcastAddr    netip.Addr
-	myVpnNet           netip.Prefix
-	dropLocalBroadcast bool
-	dropMulticast      bool
-	routines           int
-	disconnectInvalid  atomic.Bool
-	closed             atomic.Bool
-	relayManager       *relayManager
+	hostMap               *HostMap
+	outside               udp.Conn
+	inside                overlay.Device
+	pki                   *PKI
+	firewall              *Firewall
+	connectionManager     *connectionManager
+	handshakeManager      *HandshakeManager
+	serveDns              bool
+	createTime            time.Time
+	lightHouse            *LightHouse
+	myBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
+	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	dropLocalBroadcast    bool
+	dropMulticast         bool
+	routines              int
+	disconnectInvalid     atomic.Bool
+	closed                atomic.Bool
+	relayManager          *relayManager
 
 	tryPromoteEvery atomic.Uint32
 	reQueryEvery    atomic.Uint32
@@ -103,9 +103,11 @@ type EncWriter interface {
 		out []byte,
 		nocopy bool,
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
+	SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp netip.Addr)
+	Handshake(vpnAddr netip.Addr)
+	GetHostInfo(vpnAddr netip.Addr) *HostInfo
+	GetCertState() *CertState
 }
 
 type sendRecvErrorConfig uint8
@@ -116,10 +118,10 @@ const (
 	sendRecvErrorPrivate
 )
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
 	switch s {
 	case sendRecvErrorPrivate:
-		return ip.Addr().IsPrivate()
+		return endpoint.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 		return true
 	case sendRecvErrorNever:
@@ -156,27 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 	}
 
-	certificate := c.pki.GetCertState().Certificate
-
+	cs := c.pki.getCertState()
 	ifce := &Interface{
-		pki:                c.pki,
-		hostMap:            c.HostMap,
-		outside:            c.Outside,
-		inside:             c.Inside,
-		cipher:             c.Cipher,
-		firewall:           c.Firewall,
-		serveDns:           c.ServeDns,
-		handshakeManager:   c.HandshakeManager,
-		createTime:         time.Now(),
-		lightHouse:         c.lightHouse,
-		dropLocalBroadcast: c.DropLocalBroadcast,
-		dropMulticast:      c.DropMulticast,
-		routines:           c.routines,
-		version:            c.version,
-		writers:            make([]udp.Conn, c.routines),
-		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnNet:           certificate.Networks()[0],
-		relayManager:       c.relayManager,
+		pki:                   c.pki,
+		hostMap:               c.HostMap,
+		outside:               c.Outside,
+		inside:                c.Inside,
+		firewall:              c.Firewall,
+		serveDns:              c.ServeDns,
+		handshakeManager:      c.HandshakeManager,
+		createTime:            time.Now(),
+		lightHouse:            c.lightHouse,
+		dropLocalBroadcast:    c.DropLocalBroadcast,
+		dropMulticast:         c.DropMulticast,
+		routines:              c.routines,
+		version:               c.version,
+		writers:               make([]udp.Conn, c.routines),
+		readers:               make([]io.ReadWriteCloser, c.routines),
+		myVpnNetworks:         cs.myVpnNetworks,
+		myVpnNetworksTable:    cs.myVpnNetworksTable,
+		myVpnAddrs:            cs.myVpnAddrs,
+		myVpnAddrsTable:       cs.myVpnAddrsTable,
+		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
+		relayManager:          c.relayManager,
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
@@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 	}
 
-	if ifce.myVpnNet.Addr().Is4() {
-		maskedAddr := certificate.Networks()[0].Masked()
-		addr := maskedAddr.Addr().As4()
-		mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen())
-		binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
-		ifce.myBroadcastAddr = netip.AddrFrom4(addr)
-	}
-
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -218,7 +214,7 @@ func (f *Interface) activate() {
 		f.l.WithError(err).Error("Failed to get udp listen address")
 	}
 
-	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
+	f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
@@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 
 	var li udp.Conn
-	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 		li = f.writers[i]
 	} else {
 		li = f.outside
 	}
 
+	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
+	plaintext := make([]byte, udp.MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
+		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	})
 }
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
@@ -408,6 +410,8 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 	udpStats := udp.NewUDPStatsEmitter(f.writers)
 
 	certExpirationGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil)
+	certDefaultVersion := metrics.GetOrRegisterGauge("certificate.default_version", nil)
+	certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil)
 
 	for {
 		select {
@@ -417,11 +421,30 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
-			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second))
+
+			certState := f.pki.getCertState()
+			defaultCrt := certState.GetDefaultCertificate()
+			certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second))
+			certDefaultVersion.Update(int64(defaultCrt.Version()))
+
+			// Report the max certificate version we are capable of using
+			if certState.v2Cert != nil {
+				certMaxVersion.Update(int64(certState.v2Cert.Version()))
+			} else {
+				certMaxVersion.Update(int64(certState.v1Cert.Version()))
+			}
 		}
 	}
 }
 
+func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return f.hostMap.QueryVpnAddr(vpnIp)
+}
+
+func (f *Interface) GetCertState() *CertState {
+	return f.pki.getCertState()
+}
+
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 

+ 0 - 2
iputil/packet.go

@@ -6,8 +6,6 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
-//TODO: IPV6-WORK can probably delete this
-
 const (
 	// Need 96 bytes for the largest reject packet:
 	// - 20 byte ipv4 header

File diff suppressed because it is too large
+ 364 - 240
lighthouse.go


+ 173 - 146
lighthouse_test.go

@@ -7,6 +7,8 @@ import (
 	"net/netip"
 	"testing"
 
+	"github.com/gaissmai/bart"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
@@ -14,62 +16,51 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
-//TODO: Add a test to ensure udpAddr is copied and not reused
-
 func TestOldIPv4Only(t *testing.T) {
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
-	var m Ip4AndPort
+	var m V4AddrPort
 	err := m.Unmarshal(b)
 	assert.NoError(t, err)
 	ip := netip.MustParseAddr("10.1.1.1")
 	bp := ip.As4()
-	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
-}
-
-func TestNewLhQuery(t *testing.T) {
-	myIp, err := netip.ParseAddr("192.1.1.1")
-	assert.NoError(t, err)
-
-	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIp)
-
-	// The result should be a nebulameta protobuf
-	assert.IsType(t, &NebulaMeta{}, a)
-
-	// It should also Marshal fine
-	b, err := a.Marshal()
-	assert.Nil(t, err)
-
-	// and then Unmarshal fine
-	n := &NebulaMeta{}
-	err = n.Unmarshal(b)
-	assert.Nil(t, err)
-
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
 }
 
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
@@ -79,7 +70,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	}
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
 
@@ -99,9 +90,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	if !assert.NoError(b, err) {
 		b.Fatal()
 	}
@@ -110,46 +107,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
 
 	vpnIp3 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp3] = NewRemoteList(nil)
+	lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
 	lh.addrMap[vpnIp3].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
-			NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
+			netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
 	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
 	vpnIp2 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp2] = NewRemoteList(nil)
+	lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
 	lh.addrMap[vpnIp2].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
-			NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
+			netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	mw := &mockEncWriter{}
 
+	hi := []netip.Addr{vpnIp2}
 	b.Run("notfound", func(b *testing.B) {
 		lhh := lh.NewRequestHandler()
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
-				VpnIp:       4,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  4,
+				V4AddrPorts: nil,
 			},
 		}
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 	})
 	b.Run("found", func(b *testing.B) {
@@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
-				VpnIp:       3,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  3,
+				V4AddrPorts: nil,
 			},
 		}
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 	})
 }
@@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	lh.ifce = &mockEncWriter{}
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
 
 	// Ensure we don't accumulate addresses
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
 
 	// Grow it back to 2
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	// Update a different host and ask about it
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Have both hosts ask about the other
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 	// Make sure we didn't get changed
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 	// Ensure proper ordering and limiting
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 		t,
-		r.msg.Details.Ip4AndPorts,
+		r.msg.Details.V4AddrPorts,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 
@@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	good := netip.MustParseAddrPort("1.128.0.99:4242")
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
 }
 
 func TestLighthouse_reload(t *testing.T) {
@@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 
 	nc := map[interface{}]interface{}{
@@ -290,13 +306,16 @@ func TestLighthouse_reload(t *testing.T) {
 }
 
 func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
-	//TODO: IPV6-WORK
-	bip := queryVpnIp.As4()
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostQuery,
-		Details: &NebulaMetaDetails{
-			VpnIp: binary.BigEndian.Uint32(bip[:]),
-		},
+		Type:    NebulaMeta_HostQuery,
+		Details: &NebulaMetaDetails{},
+	}
+
+	if queryVpnIp.Is4() {
+		bip := queryVpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp)
 	}
 
 	b, err := req.Marshal()
@@ -308,23 +327,29 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
 	w := &testEncWriter{
 		metaFilter: &filter,
 	}
-	lhh.HandleRequest(fromAddr, myVpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
 	return w.lastReply
 }
 
 func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
-	//TODO: IPV6-WORK
-	bip := vpnIp.As4()
 	req := &NebulaMeta{
-		Type: NebulaMeta_HostUpdateNotification,
-		Details: &NebulaMetaDetails{
-			VpnIp:       binary.BigEndian.Uint32(bip[:]),
-			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
-		},
+		Type:    NebulaMeta_HostUpdateNotification,
+		Details: &NebulaMetaDetails{},
 	}
 
-	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
+	if vpnIp.Is4() {
+		bip := vpnIp.As4()
+		req.Details.OldVpnAddr = binary.BigEndian.Uint32(bip[:])
+	} else {
+		req.Details.VpnAddr = netAddrToProtoAddr(vpnIp)
+	}
+
+	for _, v := range addrs {
+		if v.Addr().Is4() {
+			req.Details.V4AddrPorts = append(req.Details.V4AddrPorts, netAddrToProtoV4AddrPort(v.Addr(), v.Port()))
+		} else {
+			req.Details.V6AddrPorts = append(req.Details.V6AddrPorts, netAddrToProtoV6AddrPort(v.Addr(), v.Port()))
+		}
 	}
 
 	b, err := req.Marshal()
@@ -333,75 +358,9 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
 	}
 
 	w := &testEncWriter{}
-	lhh.HandleRequest(fromAddr, vpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
 }
 
-//TODO: this is a RemoteList test
-//func Test_lhRemoteAllowList(t *testing.T) {
-//	l := NewLogger()
-//	c := NewConfig(l)
-//	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
-//		"10.20.0.0/12": false,
-//	}
-//	allowList, err := c.GetAllowList("remoteallowlist", false)
-//	assert.Nil(t, err)
-//
-//	lh1 := "10.128.0.2"
-//	lh1IP := net.ParseIP(lh1)
-//
-//	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-//
-//	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-//	lh.SetRemoteAllowList(allowList)
-//
-//	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
-//	remote1IP := net.ParseIP("10.20.0.3")
-//	remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
-//	remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
-//	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
-//	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
-//
-//	// Make sure a good ip enters the cache and addrMap
-//	remote2IP := net.ParseIP("10.128.0.3")
-//	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
-//
-//	// Another good ip gets into the cache, ordering is inverted
-//	remote3IP := net.ParseIP("10.128.0.4")
-//	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
-//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
-//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
-//
-//	// If we exceed the length limit we should only have the most recent addresses
-//	addedAddrs := []*udpAddr{}
-//	for i := 0; i < 11; i++ {
-//		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
-//		lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
-//		// The first entry here is a duplicate, don't add it to the assert list
-//		if i != 0 {
-//			addedAddrs = append(addedAddrs, remoteUDPAddr)
-//		}
-//	}
-//
-//	// We should only have the last 10 of what we tried to add
-//	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
-//	assertUdpAddrInArray(
-//		t,
-//		lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
-//		addedAddrs[0],
-//		addedAddrs[1],
-//		addedAddrs[2],
-//		addedAddrs[3],
-//		addedAddrs[4],
-//		addedAddrs[5],
-//		addedAddrs[6],
-//		addedAddrs[7],
-//		addedAddrs[8],
-//		addedAddrs[9],
-//	)
-//}
-
 type testLhReply struct {
 	nebType    header.MessageType
 	nebSubType header.MessageSubType
@@ -410,8 +369,9 @@ type testLhReply struct {
 }
 
 type testEncWriter struct {
-	lastReply  testLhReply
-	metaFilter *NebulaMeta_MessageType
+	lastReply       testLhReply
+	metaFilter      *NebulaMeta_MessageType
+	protocolVersion cert.Version
 }
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@@ -426,7 +386,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 		tw.lastReply = testLhReply{
 			nebType:    t,
 			nebSubType: st,
-			vpnIp:      hostinfo.vpnIp,
+			vpnIp:      hostinfo.vpnAddrs[0],
 			msg:        msg,
 		}
 	}
@@ -436,7 +396,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 }
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -453,17 +413,84 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	}
 }
 
+func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return nil
+}
+
+func (tw *testEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: tw.protocolVersion}
+}
+
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
+func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 		return
 	}
 
 	for k, w := range want {
-		//TODO: IPV6-WORK
-		h := AddrPortFromIp4AndPort(have[k])
+		h := protoV4AddrPortToNetAddrPort(have[k])
 		if !(h == w) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 	}
 }
+
+func Test_findNetworkUnion(t *testing.T) {
+	var out netip.Addr
+	var ok bool
+
+	tenDot := netip.MustParsePrefix("10.0.0.0/8")
+	oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
+	fe80 := netip.MustParsePrefix("fe80::/8")
+	fc00 := netip.MustParsePrefix("fc00::/7")
+
+	a1 := netip.MustParseAddr("10.0.0.1")
+	afe81 := netip.MustParseAddr("fe80::1")
+
+	//simple
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed lengths
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed family
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//ordering
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//some mismatches
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//falsey cases
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+}

+ 5 - 24
main.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"context"
-	"encoding/binary"
 	"fmt"
 	"net"
 	"net/netip"
@@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
 
-	certificate := pki.GetCertState().Certificate
-	fw, err := NewFirewallFromConfig(l, certificate, c)
+	fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
-	tunCidr := certificate.Networks()[0]
-
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			deviceFactory = overlay.NewDeviceFromConfig
 		}
 
-		tun, err = deviceFactory(c, l, tunCidr, routines)
+		tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
@@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	hostMap := NewHostMapFromConfig(l, tunCidr, c)
+	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
@@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		pki:                     pki,
-		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
@@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		l:                     l,
 	}
 
-	switch ifConfig.Cipher {
-	case "aes":
-		noiseEndianness = binary.BigEndian
-	case "chachapoly":
-		noiseEndianness = binary.LittleEndian
-	default:
-		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
-	}
-
 	var ifce *Interface
 	if !configTest {
 		ifce, err = NewInterface(ctx, ifConfig)
@@ -270,8 +256,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			return nil, fmt.Errorf("failed to initialize interface: %s", err)
 		}
 
-		// TODO: Better way to attach these, probably want a new interface in InterfaceConfig
-		// I don't want to make this initial commit too far-reaching though
 		ifce.writers = udpConns
 		lightHouse.ifce = ifce
 
@@ -283,8 +267,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		go handshakeManager.Run(ctx)
 	}
 
-	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
-	// a context so that they can exit when the context is Done.
 	statsStart, err := startStats(l, c, buildVersion, configTest)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
@@ -294,7 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, nil
 	}
 
-	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
 	attachCommands(l, c, ssh, ifce)
@@ -303,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	var dnsStart func()
 	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, c)
+		dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
 	}
 
 	return &Control{

+ 0 - 2
message_metrics.go

@@ -7,8 +7,6 @@ import (
 	"github.com/slackhq/nebula/header"
 )
 
-//TODO: this can probably move into the header package
-
 type MessageMetrics struct {
 	rx [][]metrics.Counter
 	tx [][]metrics.Counter

File diff suppressed because it is too large
+ 467 - 171
nebula.pb.go


+ 23 - 9
nebula.proto

@@ -23,19 +23,28 @@ message NebulaMeta {
 }
 
 message NebulaMetaDetails {
-  uint32 VpnIp = 1;
-  repeated Ip4AndPort Ip4AndPorts = 2;
-  repeated Ip6AndPort Ip6AndPorts = 4;
-  repeated uint32 RelayVpnIp = 5;
+  uint32 OldVpnAddr = 1 [deprecated = true];
+  Addr VpnAddr = 6;
+
+  repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
+  repeated Addr RelayVpnAddrs = 7;
+
+  repeated V4AddrPort V4AddrPorts = 2;
+  repeated V6AddrPort V6AddrPorts = 4;
   uint32 counter = 3;
 }
 
-message Ip4AndPort {
-  uint32 Ip = 1;
+message Addr {
+  uint64 Hi = 1;
+  uint64 Lo = 2;
+}
+
+message V4AddrPort {
+  uint32 Addr = 1;
   uint32 Port = 2;
 }
 
-message Ip6AndPort {
+message V6AddrPort {
   uint64 Hi = 1;
   uint64 Lo = 2;
   uint32 Port = 3;
@@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
   uint32 ResponderIndex = 3;
   uint64 Cookie = 4;
   uint64 Time = 5;
+  uint32 CertVersion = 8;
   // reserved for WIP multiport
   reserved 6, 7;
 }
@@ -76,6 +86,10 @@ message NebulaControl {
 
   uint32 InitiatorRelayIndex = 2;
   uint32 ResponderRelayIndex = 3;
-  uint32 RelayToIp = 4;
-  uint32 RelayFromIp = 5;
+
+  uint32 OldRelayToAddr = 4 [deprecated = true];
+  uint32 OldRelayFromAddr = 5 [deprecated = true];
+
+  Addr RelayToAddr = 6;
+  Addr RelayFromAddr = 7;
 }

+ 154 - 118
outside.go

@@ -3,16 +3,15 @@ package nebula
 import (
 	"encoding/binary"
 	"errors"
-	"fmt"
 	"net/netip"
 	"time"
 
-	"github.com/flynn/noise"
+	"github.com/google/gopacket/layers"
+	"golang.org/x/net/ipv6"
+
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 )
 
@@ -20,28 +19,9 @@ const (
 	minFwPacketLen = 4
 )
 
-// TODO: IPV6-WORK this can likely be removed now
-func readOutsidePackets(f *Interface) udp.EncReader {
-	return func(
-		addr netip.AddrPort,
-		out []byte,
-		packet []byte,
-		header *header.H,
-		fwPacket *firewall.Packet,
-		lhh udp.LightHouseHandlerFunc,
-		nb []byte,
-		q int,
-		localCache firewall.ConntrackCache,
-	) {
-		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
-	}
-}
-
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
-		// TODO: best if we return this and let caller log
-		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
 			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
@@ -51,7 +31,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	if ip.IsValid() {
-		if f.myVpnNet.Contains(ip.Addr()) {
+		_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
+		if found {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
@@ -108,7 +89,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			if !ok {
 				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
 				// its internal mapping. This should never happen.
-				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
 				return
 			}
 
@@ -120,9 +101,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
-				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
+				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
 				if err != nil {
-					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
+					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
 					return
 				}
 
@@ -138,7 +119,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 					}
 				} else {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
+					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 					return
 				}
 			}
@@ -155,13 +136,10 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
-
-			//TODO: maybe after build 64 is out? 06/14/2018 - NB
-			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
 			return
 		}
 
-		lhf(ip, hostinfo.vpnIp, d)
+		lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
 
 		// Fallthrough to the bottom to record incoming traffic
 
@@ -176,9 +154,6 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
-
-			//TODO: maybe after build 64 is out? 06/14/2018 - NB
-			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
 			return
 		}
 
@@ -228,14 +203,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				Error("Failed to decrypt Control packet")
 			return
 		}
-		m := &NebulaControl{}
-		err = m.Unmarshal(d)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
-			break
-		}
 
-		f.relayManager.HandleControlMsg(hostinfo, m, f)
+		f.relayManager.HandleControlMsg(hostinfo, d, f)
 
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -252,8 +221,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	if final {
-		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
-		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
+		// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
+		f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
 	}
 }
 
@@ -262,25 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
-	if ip.IsValid() && hostinfo.remote != ip {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
+	if udpAddr.IsValid() && hostinfo.remote != udpAddr {
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
-		if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+
+		if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(ip)
+		hostinfo.SetRemote(udpAddr)
 	}
 
 }
@@ -300,24 +270,141 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
 	return true
 }
 
+var (
+	ErrPacketTooShort          = errors.New("packet is too short")
+	ErrUnknownIPVersion        = errors.New("packet is an unknown ip version")
+	ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
+	ErrIPv4PacketTooShort      = errors.New("ipv4 packet is too short")
+	ErrIPv6PacketTooShort      = errors.New("ipv6 packet is too short")
+	ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
+)
+
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
-	// Do we at least have an ipv4 header worth of data?
-	if len(data) < ipv4.HeaderLen {
-		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
+	if len(data) < 1 {
+		return ErrPacketTooShort
+	}
+
+	version := int((data[0] >> 4) & 0x0f)
+	switch version {
+	case ipv4.Version:
+		return parseV4(data, incoming, fp)
+	case ipv6.Version:
+		return parseV6(data, incoming, fp)
+	}
+	return ErrUnknownIPVersion
+}
+
+func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
+	dataLen := len(data)
+	if dataLen < ipv6.HeaderLen {
+		return ErrIPv6PacketTooShort
+	}
+
+	if incoming {
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
+	} else {
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
+	}
+
+	protoAt := 6             // NextHeader is at 6 bytes into the ipv6 header
+	offset := ipv6.HeaderLen // Start at the end of the ipv6 header
+	next := 0
+	for {
+		if dataLen < offset {
+			break
+		}
+
+		proto := layers.IPProtocol(data[protoAt])
+		//fmt.Println(proto, protoAt)
+		switch proto {
+		case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
+			fp.Protocol = uint8(proto)
+			fp.RemotePort = 0
+			fp.LocalPort = 0
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolTCP, layers.IPProtocolUDP:
+			if dataLen < offset+4 {
+				return ErrIPv6PacketTooShort
+			}
+
+			fp.Protocol = uint8(proto)
+			if incoming {
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			} else {
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			}
+
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolIPv6Fragment:
+			// Fragment header is 8 bytes, need at least offset+4 to read the offset field
+			if dataLen < offset+8 {
+				return ErrIPv6PacketTooShort
+			}
+
+			// Check if this is the first fragment
+			fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
+			if fragmentOffset != 0 {
+				// Non-first fragment, use what we have now and stop processing
+				fp.Protocol = data[offset]
+				fp.Fragment = true
+				fp.RemotePort = 0
+				fp.LocalPort = 0
+				return nil
+			}
+
+			// The next loop should be the transport layer since we are the first fragment
+			next = 8 // Fragment headers are always 8 bytes
+
+		case layers.IPProtocolAH:
+			// Auth headers, used by IPSec, have a different meaning for header length
+			if dataLen < offset+1 {
+				break
+			}
+
+			next = int(data[offset+1]+2) << 2
+
+		default:
+			// Normal ipv6 header length processing
+			if dataLen < offset+1 {
+				break
+			}
+
+			next = int(data[offset+1]+1) << 3
+		}
+
+		if next <= 0 {
+			// Safety check, each ipv6 header has to be at least 8 bytes
+			next = 8
+		}
+
+		protoAt = offset
+		offset = offset + next
 	}
 
-	// Is it an ipv4 packet?
-	if int((data[0]>>4)&0x0f) != 4 {
-		return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
+	return ErrIPv6CouldNotFindPayload
+}
+
+func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
+	// Do we at least have an ipv4 header worth of data?
+	if len(data) < ipv4.HeaderLen {
+		return ErrIPv4PacketTooShort
 	}
 
 	// Adjust our start position based on the advertised ip header length
 	ihl := int(data[0]&0x0f) << 2
 
-	// Well formed ip header length?
+	// Well-formed ip header length?
 	if ihl < ipv4.HeaderLen {
-		return fmt.Errorf("packet had an invalid header length: %v", ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 
 	// Check if this is the second or further fragment of a fragmented packet.
@@ -333,14 +420,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 		minLen += minFwPacketLen
 	}
 	if len(data) < minLen {
-		return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
+		return ErrIPv4InvalidHeaderLength
 	}
 
 	// Firewall packets are locally oriented
 	if incoming {
-		//TODO: IPV6-WORK
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
@@ -349,9 +435,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 	} else {
-		//TODO: IPV6-WORK
-		fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
@@ -386,8 +471,6 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
-		//TODO: maybe after build 64 is out? 06/14/2018 - NB
-		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
 		return false
 	}
 
@@ -434,9 +517,8 @@ func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
 func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
-	//TODO: this should be a signed message so we can trust that we should drop the index
 	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
-	f.outside.WriteTo(b, endpoint)
+	_ = f.outside.WriteTo(b, endpoint)
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", index).
 			WithField("udpAddr", endpoint).
@@ -470,49 +552,3 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 	// We also delete it from pending hostmap to allow for fast reconnect.
 	f.handshakeManager.DeleteHostInfo(hostinfo)
 }
-
-/*
-func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
-	if ci.eKey != nil {
-		//TODO: log error?
-		return
-	}
-
-	msg, err := proto.Marshal(meta)
-	if err != nil {
-		l.Debugln("failed to encode header")
-	}
-
-	c := ci.messageCounter
-	b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
-	ci.messageCounter++
-
-	msg := ci.eKey.EncryptDanger(b, nil, msg, c)
-	//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
-	f.outside.WriteTo(msg, endpoint)
-}
-*/
-
-func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) {
-	pk := h.PeerStatic()
-
-	if pk == nil {
-		return nil, errors.New("no peer static key was present")
-	}
-
-	if rawCertBytes == nil {
-		return nil, errors.New("provided payload was empty")
-	}
-
-	c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk)
-	if err != nil {
-		return nil, fmt.Errorf("error unmarshaling cert: %w", err)
-	}
-
-	cc, err := caPool.VerifyCertificate(time.Now(), c)
-	if err != nil {
-		return nil, fmt.Errorf("certificate validation failed: %w", err)
-	}
-
-	return cc, nil
-}

+ 525 - 17
outside_test.go

@@ -1,10 +1,15 @@
 package nebula
 
 import (
+	"bytes"
+	"encoding/binary"
 	"net"
 	"net/netip"
 	"testing"
 
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers"
+
 	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
@@ -13,9 +18,15 @@ import (
 func Test_newPacket(t *testing.T) {
 	p := &firewall.Packet{}
 
-	// length fail
-	err := newPacket([]byte{0, 1}, true, p)
-	assert.EqualError(t, err, "packet is less than 20 bytes")
+	// length fails
+	err := newPacket([]byte{}, true, p)
+	assert.ErrorIs(t, err, ErrPacketTooShort)
+
+	err = newPacket([]byte{0x40}, true, p)
+	assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
+
+	err = newPacket([]byte{0x60}, true, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
 
 	// length fail with ip options
 	h := ipv4.Header{
@@ -28,16 +39,15 @@ func Test_newPacket(t *testing.T) {
 
 	b, _ := h.Marshal()
 	err = newPacket(b, true, p)
-
-	assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 	// not an ipv4 packet
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet is not ipv4, type: 0")
+	assert.ErrorIs(t, err, ErrUnknownIPVersion)
 
 	// invalid ihl
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet had an invalid header length: 8")
+	assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
 
 	// account for variable ip header length - incoming
 	h = ipv4.Header{
@@ -54,11 +64,12 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, true, p)
 
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemotePort, uint16(3))
-	assert.Equal(t, p.LocalPort, uint16(4))
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
+	assert.Equal(t, uint16(3), p.RemotePort)
+	assert.Equal(t, uint16(4), p.LocalPort)
+	assert.False(t, p.Fragment)
 
 	// account for variable ip header length - outgoing
 	h = ipv4.Header{
@@ -75,9 +86,506 @@ func Test_newPacket(t *testing.T) {
 	err = newPacket(b, false, p)
 
 	assert.Nil(t, err)
-	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemotePort, uint16(6))
-	assert.Equal(t, p.LocalPort, uint16(5))
+	assert.Equal(t, uint8(2), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
+	assert.Equal(t, uint16(6), p.RemotePort)
+	assert.Equal(t, uint16(5), p.LocalPort)
+	assert.False(t, p.Fragment)
+}
+
+func Test_newPacket_v6(t *testing.T) {
+	p := &firewall.Packet{}
+
+	// invalid ipv6
+	ip := layers.IPv6{
+		Version:  6,
+		HopLimit: 128,
+		SrcIP:    net.IPv6linklocalallrouters,
+		DstIP:    net.IPv6linklocalallnodes,
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opt := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       false,
+	}
+	err := gopacket.SerializeLayers(buffer, opt, &ip)
+	assert.NoError(t, err)
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good ICMP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolICMPv6,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	icmp := layers.ICMPv6{}
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
+	if err != nil {
+		panic(err)
+	}
+
+	err = newPacket(buffer.Bytes(), true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good ESP packet
+	b := buffer.Bytes()
+	b[6] = byte(layers.IPProtocolESP)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// A good None packet
+	b = buffer.Bytes()
+	b[6] = byte(layers.IPProtocolNoNextHeader)
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// An unknown protocol packet
+	b = buffer.Bytes()
+	b[6] = 255 // 255 is a reserved protocol number
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+
+	// A good UDP packet
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: firewall.ProtoUDP,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+	err = udp.SetNetworkLayerForChecksum(&ip)
+	assert.NoError(t, err)
+
+	buffer.Clear()
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
+	if err != nil {
+		panic(err)
+	}
+	b = buffer.Bytes()
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short UDP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good TCP packet
+	b[6] = byte(layers.IPProtocolTCP)
+
+	// incoming
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// outgoing
+	err = newPacket(b, false, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Too short TCP packet
+	err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+
+	// A good UDP packet with an AH header
+	ip = layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolAH,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	ah := layers.IPSecAH{
+		AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
+	}
+	ah.NextHeader = layers.IPProtocolUDP
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opt)
+	if err != nil {
+		panic(err)
+	}
+
+	b = buffer.Bytes()
+	ahb := serializeAH(&ah)
+	b = append(b, ahb...)
+	b = append(b, udpHeader...)
+
+	err = newPacket(b, true, p)
+	assert.Nil(t, err)
+	assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Invalid AH header
+	b = buffer.Bytes()
+	err = newPacket(b, true, p)
+	assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
+}
+
+func Test_newPacket_ipv6Fragment(t *testing.T) {
+	p := &firewall.Packet{}
+
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	// First fragment
+	fragHeader1 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x1b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: true,
+		FixLengths:       true,
+	}
+
+	err := ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader1...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test first fragment incoming
+	err = newPacket(firstFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.RemotePort)
+	assert.Equal(t, uint16(22), p.LocalPort)
+	assert.False(t, p.Fragment)
+
+	// Test first fragment outgoing
+	err = newPacket(firstFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(36123), p.LocalPort)
+	assert.Equal(t, uint16(22), p.RemotePort)
+	assert.False(t, p.Fragment)
+
+	// Second fragment
+	fragHeader2 := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0xb9,                        // Fragment Offset high byte (185)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	buffer.Clear()
+	err = ip.SerializeTo(buffer, opts)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader2...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Test second fragment incoming
+	err = newPacket(secondFrag, true, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.True(t, p.Fragment)
+
+	// Test second fragment outgoing
+	err = newPacket(secondFrag, false, p)
+	assert.NoError(t, err)
+	assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
+	assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
+	assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
+	assert.Equal(t, uint16(0), p.LocalPort)
+	assert.Equal(t, uint16(0), p.RemotePort)
+	assert.True(t, p.Fragment)
+
+	// Too short of a fragment packet
+	err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
+	assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
+}
+
+func BenchmarkParseV6(b *testing.B) {
+	// Regular UDP packet
+	ip := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolUDP,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := &layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opts := gopacket.SerializeOptions{
+		ComputeChecksums: false,
+		FixLengths:       true,
+	}
+
+	err := gopacket.SerializeLayers(buffer, opts, ip, udp)
+	if err != nil {
+		b.Fatal(err)
+	}
+	normalPacket := buffer.Bytes()
+
+	// First Fragment packet
+	ipFrag := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6Fragment,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	fragHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Reserved
+		0x00,                        // Fragment Offset high byte (0)
+		0x01,                        // Fragment Offset low byte & flags (M=1)
+		0x00, 0x00, 0x00, 0x01,      // Identification
+	}
+
+	udpHeader := []byte{
+		0x8d, 0x7b, // Source port 36123
+		0x00, 0x16, // Destination port 22
+		0x00, 0x00, // Length
+		0x00, 0x00, // Checksum
+	}
+
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	firstFrag := buffer.Bytes()
+	firstFrag = append(firstFrag, fragHeader...)
+	firstFrag = append(firstFrag, udpHeader...)
+	firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	// Second Fragment packet
+	fragHeader[2] = 0xb9 // offset 185
+	buffer.Clear()
+	err = ipFrag.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	secondFrag := buffer.Bytes()
+	secondFrag = append(secondFrag, fragHeader...)
+	secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	fp := &firewall.Packet{}
+
+	b.Run("Normal", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(normalPacket, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("FirstFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(firstFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	b.Run("SecondFragment", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(secondFrag, true, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+
+	// Evil packet
+	evilPacket := &layers.IPv6{
+		Version:    6,
+		NextHeader: layers.IPProtocolIPv6HopByHop,
+		HopLimit:   64,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	hopHeader := []byte{
+		uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
+		0x00,                                 // Length
+		0x00, 0x00,                           // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	lastHopHeader := []byte{
+		uint8(layers.IPProtocolUDP), // Next Header (UDP)
+		0x00,                        // Length
+		0x00, 0x00,                  // Options and padding
+		0x00, 0x00, 0x00, 0x00, // More options and padding
+	}
+
+	buffer.Clear()
+	err = evilPacket.SerializeTo(buffer, opts)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	evilBytes := buffer.Bytes()
+	for i := 0; i < 200; i++ {
+		evilBytes = append(evilBytes, hopHeader...)
+	}
+	evilBytes = append(evilBytes, lastHopHeader...)
+	evilBytes = append(evilBytes, udpHeader...)
+	evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
+
+	b.Run("200 HopByHop headers", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			if err = parseV6(evilBytes, false, fp); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+}
+
+// Ensure authentication data is a multiple of 8 bytes by padding if necessary
+func padAuthData(authData []byte) []byte {
+	// Length of Authentication Data must be a multiple of 8 bytes
+	paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
+	if paddingLength > 0 {
+		authData = append(authData, make([]byte, paddingLength)...)
+	}
+	return authData
+}
+
+// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
+func serializeAH(ah *layers.IPSecAH) []byte {
+	buf := new(bytes.Buffer)
+
+	// Ensure Authentication Data is a multiple of 8 bytes
+	ah.AuthenticationData = padAuthData(ah.AuthenticationData)
+	// Calculate Payload Length (in 32-bit words, minus 2)
+	payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
+
+	// Serialize fields
+	if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
+		panic(err)
+	}
+	if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
+		panic(err)
+	}
+	if len(ah.AuthenticationData) > 0 {
+		if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
+			panic(err)
+		}
+	}
+
+	return buf.Bytes()
 }

+ 1 - 1
overlay/device.go

@@ -8,7 +8,7 @@ import (
 type Device interface {
 	io.ReadWriteCloser
 	Activate() error
-	Cidr() netip.Prefix
+	Networks() []netip.Prefix
 	Name() string
 	RouteFor(netip.Addr) netip.Addr
 	NewMultiQueueReader() (io.ReadWriteCloser, error)

+ 22 - 12
overlay/route.go

@@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
 	return routeTree, nil
 }
 
-func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 
 	r := c.Get("tun.routes")
@@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 		}
 
-		if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
+		found := false
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
+				found = true
+				break
+			}
+		}
+
+		if !found {
 			return nil, fmt.Errorf(
-				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
+				"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
 				i+1,
 				r.Cidr.String(),
-				network.String(),
+				networks,
 			)
 		}
 
@@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	return routes, nil
 }
 
-func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 
 	r := c.Get("tun.unsafe_routes")
@@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 		}
 
-		if network.Contains(r.Cidr.Addr()) {
-			return nil, fmt.Errorf(
-				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
-				i+1,
-				r.Cidr.String(),
-				network.String(),
-			)
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) {
+				return nil, fmt.Errorf(
+					"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
+					i+1,
+					r.Cidr.String(),
+					network.String(),
+				)
+			}
 		}
 
 		routes[i] = r

+ 39 - 33
overlay/route_test.go

@@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
 	assert.NoError(t, err)
 
 	// test no routes config
-	routes, err := parseRoutes(c, n)
+	routes, err := parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.routes is not an array")
 
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
 
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
+
+	// Not in multiple ranges
+	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
+	routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
+	assert.Nil(t, routes)
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
 
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 2)
 
@@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	assert.NoError(t, err)
 
 	// test no routes config
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 	// no via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
@@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		127, false, nil, 1.0, []string{"1", "2"},
 	} {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
-		routes, err = parseUnsafeRoutes(c, n)
+		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 	}
 
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Equal(t, 0, routes[0].MTU)
 
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 	// bad install
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 
@@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Len(t, routes, 4)
 
@@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 	}}
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.NoError(t, err)
 	assert.Len(t, routes, 2)
 	routeTree, err := makeRouteTree(l, routes, true)

+ 9 - 9
overlay/tun.go

@@ -11,36 +11,36 @@ import (
 const DefaultMTU = 1300
 
 // TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
 	switch {
 	case c.GetBool("tun.disabled", false):
-		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
+		tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 
 	default:
-		return newTun(c, l, tunCidr, routines > 1)
+		return newTun(c, l, vpnNetworks, routines > 1)
 	}
 }
 
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
-		return newTunFromFd(c, l, *fd, tunCidr)
+	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+		return newTunFromFd(c, l, *fd, vpnNetworks)
 	}
 }
 
-func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 		return false, nil, nil
 	}
 
-	routes, err := parseRoutes(c, cidr)
+	routes, err := parseRoutes(c, vpnNetworks)
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 	}
 
-	unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
+	unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}

+ 11 - 11
overlay/tun_android.go

@@ -18,14 +18,14 @@ import (
 
 type tun struct {
 	io.ReadWriteCloser
-	fd        int
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	fd          int
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	t := &tun{
 		ReadWriteCloser: file,
 		fd:              deviceFd,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		l:               l,
 	}
 
@@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 }
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 
@@ -66,7 +66,7 @@ func (t tun) Activate() error {
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 208 - 215
overlay/tun_darwin.go

@@ -24,56 +24,62 @@ import (
 
 type tun struct {
 	io.ReadWriteCloser
-	Device     string
-	cidr       netip.Prefix
-	DefaultMTU int
-	Routes     atomic.Pointer[[]Route]
-	routeTree  atomic.Pointer[bart.Table[netip.Addr]]
-	linkAddr   *netroute.LinkAddr
-	l          *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	DefaultMTU  int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	linkAddr    *netroute.LinkAddr
+	l           *logrus.Logger
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	out []byte
 }
 
-type sockaddrCtl struct {
-	scLen      uint8
-	scFamily   uint8
-	ssSysaddr  uint16
-	scID       uint32
-	scUnit     uint32
-	scReserved [5]uint32
-}
-
 type ifReq struct {
-	Name  [16]byte
+	Name  [unix.IFNAMSIZ]byte
 	Flags uint16
 	pad   [8]byte
 }
 
-var sockaddrCtlSize uintptr = 32
-
 const (
-	_SYSPROTO_CONTROL = 2              //define SYSPROTO_CONTROL 2 /* kernel control protocol */
-	_AF_SYS_CONTROL   = 2              //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
-	_PF_SYSTEM        = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
-	_CTLIOCGINFO      = 3227799043     //#define CTLIOCGINFO     _IOWR('N', 3, struct ctl_info)
-	utunControlName   = "com.apple.net.utun_control"
+	_SIOCAIFADDR_IN6 = 2155899162
+	_UTUN_OPT_IFNAME = 2
+	_IN6_IFF_NODAD   = 0x0020
+	_IN6_IFF_SECURED = 0x0400
+	utunControlName  = "com.apple.net.utun_control"
 )
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 	Name [16]byte
 	MTU  int32
 	pad  [8]byte
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+type addrLifetime struct {
+	Expire    float64
+	Preferred float64
+	Vltime    uint32
+	Pltime    uint32
+}
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	name := c.GetString("tun.dev", "")
 	ifIndex := -1
 	if name != "" && name != "utun" {
@@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 		}
 	}
 
-	fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
+	fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
 	if err != nil {
 		return nil, fmt.Errorf("system socket: %v", err)
 	}
 
-	var ctlInfo = &struct {
-		ctlID   uint32
-		ctlName [96]byte
-	}{}
+	var ctlInfo = &unix.CtlInfo{}
+	copy(ctlInfo.Name[:], utunControlName)
 
-	copy(ctlInfo.ctlName[:], utunControlName)
-
-	err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
+	err = unix.IoctlCtlInfo(fd, ctlInfo)
 	if err != nil {
 		return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
 	}
 
-	sc := sockaddrCtl{
-		scLen:     uint8(sockaddrCtlSize),
-		scFamily:  unix.AF_SYSTEM,
-		ssSysaddr: _AF_SYS_CONTROL,
-		scID:      ctlInfo.ctlID,
-		scUnit:    uint32(ifIndex) + 1,
-	}
-
-	_, _, errno := unix.RawSyscall(
-		unix.SYS_CONNECT,
-		uintptr(fd),
-		uintptr(unsafe.Pointer(&sc)),
-		sockaddrCtlSize,
-	)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
+	err = unix.Connect(fd, &unix.SockaddrCtl{
+		ID:   ctlInfo.Id,
+		Unit: uint32(ifIndex) + 1,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("SYS_CONNECT: %v", err)
 	}
 
-	var ifName struct {
-		name [16]byte
-	}
-	ifNameSize := uintptr(len(ifName.name))
-	_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
-		2, // SYSPROTO_CONTROL
-		2, // UTUN_OPT_IFNAME
-		uintptr(unsafe.Pointer(&ifName)),
-		uintptr(unsafe.Pointer(&ifNameSize)), 0)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
+	name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
+	if err != nil {
+		return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
 	}
-	name = string(ifName.name[:ifNameSize-1])
 
-	err = syscall.SetNonblock(fd, true)
+	err = unix.SetNonblock(fd, true)
 	if err != nil {
 		return nil, fmt.Errorf("SetNonblock: %v", err)
 	}
 
-	file := os.NewFile(uintptr(fd), "")
-
 	t := &tun{
-		ReadWriteCloser: file,
+		ReadWriteCloser: os.NewFile(uintptr(fd), ""),
 		Device:          name,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		DefaultMTU:      c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 
@@ -186,16 +167,6 @@ func (t *tun) Close() error {
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 
-	var addr, mask [4]byte
-
-	if !t.cidr.Addr().Is4() {
-		//TODO: IPV6-WORK
-		panic("need ipv6")
-	}
-
-	addr = t.cidr.Addr().As4()
-	copy(mask[:], prefixToMask(t.cidr))
-
 	s, err := unix.Socket(
 		unix.AF_INET,
 		unix.SOCK_DGRAM,
@@ -208,66 +179,18 @@ func (t *tun) Activate() error {
 
 	fd := uintptr(s)
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
-	// Set the device name
-	ifrf := ifReq{Name: devName}
-	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to set tun device name: %s", err)
-	}
-
 	// Set the MTU on the device
 	ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 		return fmt.Errorf("failed to set tun mtu: %v", err)
 	}
 
-	/*
-		// Set the transmit queue length
-		ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
-		if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
-			// If we can't set the queue length nebula will still work but it may lead to packet loss
-			l.WithError(err).Error("Failed to set tun tx queue length")
-		}
-	*/
-
-	// Bring up the interface
-	ifrf.Flags = ifrf.Flags | unix.IFF_UP
-	if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to bring the tun device up: %s", err)
-	}
-
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	// Get the device flags
+	ifrf := ifReq{Name: devName}
+	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+		return fmt.Errorf("failed to get tun flags: %s", err)
 	}
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
 
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	linkAddr, err := getLinkAddr(t.Device)
 	if err != nil {
 		return err
@@ -277,14 +200,18 @@ func (t *tun) Activate() error {
 	}
 	t.linkAddr = linkAddr
 
-	copy(routeAddr.IP[:], addr[:])
-	copy(maskAddr.IP[:], mask[:])
-	err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
-	if err != nil {
-		if errors.Is(err, unix.EEXIST) {
-			err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
+	for _, network := range t.vpnNetworks {
+		if network.Addr().Is4() {
+			err = t.activate4(network)
+			if err != nil {
+				return err
+			}
+		} else {
+			err = t.activate6(network)
+			if err != nil {
+				return err
+			}
 		}
-		return err
 	}
 
 	// Run the interface
@@ -297,8 +224,89 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) activate4(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias4{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		DstAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		MaskAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(network).As4(),
+		},
+	}
+
+	if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun v4 address: %s", err)
+	}
+
+	err = addRoute(network, t.linkAddr)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (t *tun) activate6(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET6,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias6{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   network.Addr().As16(),
+		},
+		PrefixMask: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(network).As16(),
+		},
+		Lifetime: addrLifetime{
+			// never expires
+			Vltime: 0xffffffff,
+			Pltime: 0xffffffff,
+		},
+		//TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address?
+		Flags: _IN6_IFF_NODAD,
+	}
+
+	if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun address: %s", err)
+	}
+
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -343,7 +351,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 }
 
 // Get the LinkAddr for the interface of the given name
-// TODO: Is there an easier way to fetch this when we create the interface?
+// Is there an easier way to fetch this when we create the interface?
 // Maybe SIOCGIFINDEX? but this doesn't appear to exist in the darwin headers.
 func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 	rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0)
@@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 }
 
 func (t *tun) addRoutes(logErrors bool) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		if !r.Cidr.Addr().Is4() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		//TODO: we could avoid the copy
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := addRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 			if errors.Is(err, unix.EEXIST) {
 				t.l.WithField("route", r.Cidr).
@@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
 }
 
 func (t *tun) removeRoutes(routes []Route) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
-
 	for _, r := range routes {
 		if !r.Install {
 			continue
 		}
 
-		if r.Cidr.Addr().Is6() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := delRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
@@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_ADD,
 		Flags:   unix.RTF_UP,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
+
 	_, err = unix.Write(sock, data[:])
 	if err != nil {
 		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 	return nil
 }
 
-func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_DELETE,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
@@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 }
 
 func (t *tun) Read(to []byte) (int, error) {
-
 	buf := make([]byte, len(to)+4)
 
 	n, err := t.ReadWriteCloser.Read(buf)
@@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
 	return n - 4, err
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {
@@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
 
-func prefixToMask(prefix netip.Prefix) []byte {
+func prefixToMask(prefix netip.Prefix) netip.Addr {
 	pLen := 128
 	if prefix.Addr().Is4() {
 		pLen = 32
 	}
-	return net.CIDRMask(prefix.Bits(), pLen)
+
+	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+	return addr
 }

+ 8 - 8
overlay/tun_disabled.go

@@ -12,8 +12,8 @@ import (
 )
 
 type disabledTun struct {
-	read chan []byte
-	cidr netip.Prefix
+	read        chan []byte
+	vpnNetworks []netip.Prefix
 
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
@@ -21,11 +21,11 @@ type disabledTun struct {
 	l  *logrus.Logger
 }
 
-func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
-		cidr: cidr,
-		read: make(chan []byte, queueLen),
-		l:    l,
+		vpnNetworks: vpnNetworks,
+		read:        make(chan []byte, queueLen),
+		l:           l,
 	}
 
 	if metricsEnabled {
@@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
 	return netip.Addr{}
 }
 
-func (t *disabledTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *disabledTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (*disabledTun) Name() string {

+ 25 - 15
overlay/tun_freebsd.go

@@ -46,12 +46,12 @@ type ifreqDestroy struct {
 }
 
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 	io.ReadWriteCloser
 }
@@ -78,11 +78,11 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
 	var file *os.File
 	var err error
@@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 		ReadWriteCloser: file,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 }
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -195,8 +195,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 10 - 10
overlay/tun_ios.go

@@ -21,20 +21,20 @@ import (
 
 type tun struct {
 	io.ReadWriteCloser
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	t := &tun{
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		ReadWriteCloser: &tunReadCloser{f: file},
 		l:               l,
 	}
@@ -59,7 +59,7 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
 	return tr.f.Close()
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 120 - 78
overlay/tun_linux.go

@@ -11,6 +11,7 @@ import (
 	"os"
 	"strings"
 	"sync/atomic"
+	"time"
 	"unsafe"
 
 	"github.com/gaissmai/bart"
@@ -25,7 +26,7 @@ type tun struct {
 	io.ReadWriteCloser
 	fd          int
 	Device      string
-	cidr        netip.Prefix
+	vpnNetworks []netip.Prefix
 	MaxMTU      int
 	DefaultMTU  int
 	TXQueueLen  int
@@ -40,18 +41,16 @@ type tun struct {
 	l *logrus.Logger
 }
 
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
+
 type ifReq struct {
 	Name  [16]byte
 	Flags uint16
 	pad   [8]byte
 }
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 	Name [16]byte
 	MTU  int32
@@ -64,10 +63,10 @@ type ifreqQLEN struct {
 	pad   [8]byte
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 		return nil, err
 	}
@@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	name := strings.Trim(string(req.Name[:]), "\x00")
 
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 		return nil, err
 	}
@@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	return t, nil
 }
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
 		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
 		l:               l,
@@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
 		}
 
 		if oldDefaultMTU != newDefaultMTU {
-			err := t.setDefaultRoute()
-			if err != nil {
-				t.l.Warn(err)
-			} else {
-				t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+			for i := range t.vpnNetworks {
+				err := t.setDefaultRoute(t.vpnNetworks[i])
+				if err != nil {
+					t.l.Warn(err)
+				} else {
+					t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+				}
 			}
 		}
 
@@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 
 func (t *tun) Write(b []byte) (int, error) {
 	var nn int
-	max := len(b)
+	maximum := len(b)
 
 	for {
-		n, err := unix.Write(t.fd, b[nn:max])
+		n, err := unix.Write(t.fd, b[nn:maximum])
 		if n > 0 {
 			nn += n
 		}
@@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
+func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
+	for i := range al {
+		if al[i].Equal(x) {
+			return true
+		}
+	}
+	return false
+}
+
+// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
+func (t *tun) addIPs(link netlink.Link) error {
+	newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
+	for i := range t.vpnNetworks {
+		newAddrs[i] = &netlink.Addr{
+			IPNet: &net.IPNet{
+				IP:   t.vpnNetworks[i].Addr().AsSlice(),
+				Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
+			},
+			Label: t.vpnNetworks[i].Addr().Zone(),
+		}
+	}
+
+	//add all new addresses
+	for i := range newAddrs {
+		//TODO: CERT-V2 do we want to stack errors and try as many ops as possible?
+		//AddrReplace still adds new IPs, but if their properties change it will change them as well
+		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
+			return err
+		}
+	}
+
+	//iterate over remainder, remove whoever shouldn't be there
+	al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
+	if err != nil {
+		return fmt.Errorf("failed to get tun address list: %s", err)
+	}
+
+	for i := range al {
+		if hasNetlinkAddr(newAddrs, al[i]) {
+			continue
+		}
+		err = netlink.AddrDel(link, &al[i])
+		if err != nil {
+			t.l.WithError(err).Error("failed to remove address from tun address list")
+		} else {
+			t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
+		}
+	}
+
+	return nil
+}
+
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 
@@ -272,15 +325,8 @@ func (t *tun) Activate() error {
 		t.watchRoutes()
 	}
 
-	var addr, mask [4]byte
-
-	//TODO: IPV6-WORK
-	addr = t.cidr.Addr().As4()
-	tmask := net.CIDRMask(t.cidr.Bits(), 32)
-	copy(mask[:], tmask)
-
 	s, err := unix.Socket(
-		unix.AF_INET,
+		unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
 		unix.SOCK_DGRAM,
 		unix.IPPROTO_IP,
 	)
@@ -289,31 +335,19 @@ func (t *tun) Activate() error {
 	}
 	t.ioctlFd = uintptr(s)
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
 	// Set the device name
 	ifrf := ifReq{Name: devName}
 	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to set tun device name: %s", err)
 	}
 
+	link, err := netlink.LinkByName(t.Device)
+	if err != nil {
+		return fmt.Errorf("failed to get tun device link: %s", err)
+	}
+
+	t.deviceIndex = link.Attrs().Index
+
 	// Setup our default MTU
 	t.setMTU()
 
@@ -324,20 +358,21 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 
+	if err = t.addIPs(link); err != nil {
+		return err
+	}
+
 	// Bring up the interface
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP
 	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to bring the tun device up: %s", err)
 	}
 
-	link, err := netlink.LinkByName(t.Device)
-	if err != nil {
-		return fmt.Errorf("failed to get tun device link: %s", err)
-	}
-	t.deviceIndex = link.Attrs().Index
-
-	if err = t.setDefaultRoute(); err != nil {
-		return err
+	//set route MTU
+	for i := range t.vpnNetworks {
+		if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
+			return fmt.Errorf("failed to set default route MTU: %w", err)
+		}
 	}
 
 	// Set the routes
@@ -363,12 +398,10 @@ func (t *tun) setMTU() {
 	}
 }
 
-func (t *tun) setDefaultRoute() error {
-	// Default route
-
+func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
 	dr := &net.IPNet{
-		IP:   t.cidr.Masked().Addr().AsSlice(),
-		Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+		IP:   cidr.Masked().Addr().AsSlice(),
+		Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
 	}
 
 	nr := netlink.Route{
@@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error {
 		MTU:       t.DefaultMTU,
 		AdvMSS:    t.advMSS(Route{}),
 		Scope:     unix.RT_SCOPE_LINK,
-		Src:       net.IP(t.cidr.Addr().AsSlice()),
+		Src:       net.IP(cidr.Addr().AsSlice()),
 		Protocol:  unix.RTPROT_KERNEL,
 		Table:     unix.RT_TABLE_MAIN,
 		Type:      unix.RTN_UNICAST,
 	}
 	err := netlink.RouteReplace(&nr)
 	if err != nil {
-		return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
+		t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying")
+		//retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument`
+		for i := 0; i < 2; i++ {
+			time.Sleep(100 * time.Millisecond)
+			err = netlink.RouteReplace(&nr)
+			if err == nil {
+				break
+			} else {
+				t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying")
+			}
+		}
+		if err != nil {
+			return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err)
+		}
 	}
 
 	return nil
@@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) {
 	}
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
 func (t *tun) Name() string {
 	return t.Device
 }
@@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	//TODO: IPV6-WORK what if not ok?
 	gwAddr, ok := netip.AddrFromSlice(r.Gw)
 	if !ok {
 		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
@@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	}
 
 	gwAddr = gwAddr.Unmap()
-	if !t.cidr.Contains(gwAddr) {
-		// Gateway isn't in our overlay network, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
-		return
+	withinNetworks := false
+	for i := range t.vpnNetworks {
+		if t.vpnNetworks[i].Contains(gwAddr) {
+			withinNetworks = true
+			break
+		}
 	}
-
-	if x := r.Dst.IP.To4(); x == nil {
-		// Nebula only handles ipv4 on the overlay currently
-		t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4")
+	if !withinNetworks {
+		// Gateway isn't in our overlay network, ignore
+		t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
 		return
 	}
 
@@ -563,11 +605,11 @@ func (t *tun) Close() error {
 	}
 
 	if t.ReadWriteCloser != nil {
-		t.ReadWriteCloser.Close()
+		_ = t.ReadWriteCloser.Close()
 	}
 
 	if t.ioctlFd > 0 {
-		os.NewFile(t.ioctlFd, "ioctlFd").Close()
+		_ = os.NewFile(t.ioctlFd, "ioctlFd").Close()
 	}
 
 	return nil

+ 28 - 17
overlay/tun_netbsd.go

@@ -27,12 +27,12 @@ type ifreqDestroy struct {
 }
 
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 	io.ReadWriteCloser
 }
@@ -58,13 +58,13 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
 	var file *os.File
 	var err error
@@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 		ReadWriteCloser: file,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 }
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -130,8 +130,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {
@@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

+ 29 - 19
overlay/tun_openbsd.go

@@ -21,12 +21,12 @@ import (
 )
 
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 	io.ReadWriteCloser
 
@@ -42,13 +42,13 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 }
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 		ReadWriteCloser: file,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 	}
@@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 }
 
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 		return err
 	}
@@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -138,7 +138,7 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -148,6 +148,16 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
@@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 		if !r.Install {
 			continue
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//TODO: CERT-V2 is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *tun) Name() string {

+ 17 - 17
overlay/tun_tester.go

@@ -16,19 +16,19 @@ import (
 )
 
 type TestTun struct {
-	Device    string
-	cidr      netip.Prefix
-	Routes    []Route
-	routeTree *bart.Table[netip.Addr]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	Routes      []Route
+	routeTree   *bart.Table[netip.Addr]
+	l           *logrus.Logger
 
 	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
-	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
+	_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
 	if err != nil {
 		return nil, err
 	}
@@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
 	}
 
 	return &TestTun{
-		Device:    c.GetString("tun.dev", ""),
-		cidr:      cidr,
-		Routes:    routes,
-		routeTree: routeTree,
-		l:         l,
-		rxPackets: make(chan []byte, 10),
-		TxPackets: make(chan []byte, 10),
+		Device:      c.GetString("tun.dev", ""),
+		vpnNetworks: vpnNetworks,
+		Routes:      routes,
+		routeTree:   routeTree,
+		l:           l,
+		rxPackets:   make(chan []byte, 10),
+		TxPackets:   make(chan []byte, 10),
 	}, nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 
@@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
 	return nil
 }
 
-func (t *TestTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *TestTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 
 func (t *TestTun) Name() string {

+ 0 - 208
overlay/tun_water_windows.go

@@ -1,208 +0,0 @@
-package overlay
-
-import (
-	"fmt"
-	"io"
-	"net"
-	"net/netip"
-	"os/exec"
-	"strconv"
-	"sync/atomic"
-
-	"github.com/gaissmai/bart"
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
-	"github.com/songgao/water"
-)
-
-type waterTun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
-	f         *net.Interface
-	*water.Interface
-}
-
-func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
-	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
-	t := &waterTun{
-		cidr: cidr,
-		MTU:  c.GetInt("tun.mtu", DefaultMTU),
-		l:    l,
-	}
-
-	err := t.reload(c, true)
-	if err != nil {
-		return nil, err
-	}
-
-	c.RegisterReloadCallback(func(c *config.C) {
-		err := t.reload(c, false)
-		if err != nil {
-			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
-		}
-	})
-
-	return t, nil
-}
-
-func (t *waterTun) Activate() error {
-	var err error
-	t.Interface, err = water.New(water.Config{
-		DeviceType: water.TUN,
-		PlatformSpecificParams: water.PlatformSpecificParams{
-			ComponentID: "tap0901",
-			Network:     t.cidr.String(),
-		},
-	})
-	if err != nil {
-		return fmt.Errorf("activate failed: %v", err)
-	}
-
-	t.Device = t.Interface.Name()
-
-	// TODO use syscalls instead of exec.Command
-	err = exec.Command(
-		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
-		fmt.Sprintf("name=%s", t.Device),
-		"source=static",
-		fmt.Sprintf("addr=%s", t.cidr.Addr()),
-		fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
-		"gateway=none",
-	).Run()
-	if err != nil {
-		return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
-	}
-	err = exec.Command(
-		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
-		t.Device,
-		fmt.Sprintf("mtu=%d", t.MTU),
-	).Run()
-	if err != nil {
-		return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
-	}
-
-	t.f, err = net.InterfaceByName(t.Device)
-	if err != nil {
-		return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
-	}
-
-	err = t.addRoutes(false)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (t *waterTun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
-	if err != nil {
-		return err
-	}
-
-	if !initial && !change {
-		return nil
-	}
-
-	routeTree, err := makeRouteTree(t.l, routes, false)
-	if err != nil {
-		return err
-	}
-
-	// Teach nebula how to handle the routes before establishing them in the system table
-	oldRoutes := t.Routes.Swap(&routes)
-	t.routeTree.Store(routeTree)
-
-	if !initial {
-		// Remove first, if the system removes a wanted route hopefully it will be re-added next
-		t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
-
-		// Ensure any routes we actually want are installed
-		err = t.addRoutes(true)
-		if err != nil {
-			// Catch any stray logs
-			util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
-		} else {
-			for _, r := range findRemovedRoutes(routes, *oldRoutes) {
-				t.l.WithField("route", r).Info("Removed route")
-			}
-		}
-	}
-
-	return nil
-}
-
-func (t *waterTun) addRoutes(logErrors bool) error {
-	// Path routes
-	routes := *t.Routes.Load()
-	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
-			// We don't allow route MTUs so only install routes with a via
-			continue
-		}
-
-		err := exec.Command(
-			"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
-		).Run()
-
-		if err != nil {
-			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
-			if logErrors {
-				retErr.Log(t.l)
-			} else {
-				return retErr
-			}
-		} else {
-			t.l.WithField("route", r).Info("Added route")
-		}
-	}
-
-	return nil
-}
-
-func (t *waterTun) removeRoutes(routes []Route) {
-	for _, r := range routes {
-		if !r.Install {
-			continue
-		}
-
-		err := exec.Command(
-			"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
-		).Run()
-		if err != nil {
-			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
-		} else {
-			t.l.WithField("route", r).Info("Removed route")
-		}
-	}
-}
-
-func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
-	r, _ := t.routeTree.Load().Lookup(ip)
-	return r
-}
-
-func (t *waterTun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
-func (t *waterTun) Name() string {
-	return t.Device
-}
-
-func (t *waterTun) Close() error {
-	if t.Interface == nil {
-		return nil
-	}
-
-	return t.Interface.Close()
-}
-
-func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
-}

+ 240 - 13
overlay/tun_windows.go

@@ -4,41 +4,268 @@
 package overlay
 
 import (
+	"crypto"
 	"fmt"
+	"io"
 	"net/netip"
 	"os"
 	"path/filepath"
 	"runtime"
+	"sync/atomic"
 	"syscall"
+	"unsafe"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/wintun"
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
 )
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
+const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
+
+type winTun struct {
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
+
+	tun *wintun.NativeTun
+}
+
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
-	useWintun := true
-	if err := checkWinTunExists(); err != nil {
-		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
-		useWintun = false
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
+	err := checkWinTunExists()
+	if err != nil {
+		return nil, fmt.Errorf("can not load the wintun driver: %w", err)
+	}
+
+	deviceName := c.GetString("tun.dev", "")
+	guid, err := generateGUIDByDeviceName(deviceName)
+	if err != nil {
+		return nil, fmt.Errorf("generate GUID failed: %w", err)
+	}
+
+	t := &winTun{
+		Device:      deviceName,
+		vpnNetworks: vpnNetworks,
+		MTU:         c.GetInt("tun.mtu", DefaultMTU),
+		l:           l,
+	}
+
+	err = t.reload(c, true)
+	if err != nil {
+		return nil, err
 	}
 
-	if useWintun {
-		device, err := newWinTun(c, l, cidr, multiqueue)
+	var tunDevice wintun.Device
+	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
+	if err != nil {
+		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
+		// Trying a second time resolves the issue.
+		l.WithError(err).Debug("Failed to create wintun device, retrying")
+		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
 		if err != nil {
-			return nil, fmt.Errorf("create Wintun interface failed, %w", err)
+			return nil, fmt.Errorf("create TUN device failed: %w", err)
 		}
-		return device, nil
 	}
+	t.tun = tunDevice.(*wintun.NativeTun)
+
+	c.RegisterReloadCallback(func(c *config.C) {
+		err := t.reload(c, false)
+		if err != nil {
+			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+		}
+	})
+
+	return t, nil
+}
 
-	device, err := newWaterTun(c, l, cidr, multiqueue)
+func (t *winTun) reload(c *config.C, initial bool) error {
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
-		return nil, fmt.Errorf("create wintap driver failed, %w", err)
+		return err
+	}
+
+	if !initial && !change {
+		return nil
 	}
-	return device, nil
+
+	routeTree, err := makeRouteTree(t.l, routes, false)
+	if err != nil {
+		return err
+	}
+
+	// Teach nebula how to handle the routes before establishing them in the system table
+	oldRoutes := t.Routes.Swap(&routes)
+	t.routeTree.Store(routeTree)
+
+	if !initial {
+		// Remove first, if the system removes a wanted route hopefully it will be re-added next
+		err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+		if err != nil {
+			util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
+		}
+
+		// Ensure any routes we actually want are installed
+		err = t.addRoutes(true)
+		if err != nil {
+			// Catch any stray logs
+			util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
+		}
+	}
+
+	return nil
+}
+
+func (t *winTun) Activate() error {
+	luid := winipcfg.LUID(t.tun.LUID())
+
+	err := luid.SetIPAddresses(t.vpnNetworks)
+	if err != nil {
+		return fmt.Errorf("failed to set address: %w", err)
+	}
+
+	err = t.addRoutes(false)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (t *winTun) addRoutes(logErrors bool) error {
+	luid := winipcfg.LUID(t.tun.LUID())
+	routes := *t.Routes.Load()
+	foundDefault4 := false
+
+	for _, r := range routes {
+		if !r.Via.IsValid() || !r.Install {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
+		// Add our unsafe route
+		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+			if logErrors {
+				retErr.Log(t.l)
+				continue
+			} else {
+				return retErr
+			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
+		}
+
+		if !foundDefault4 {
+			if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
+				foundDefault4 = true
+			}
+		}
+	}
+
+	ipif, err := luid.IPInterface(windows.AF_INET)
+	if err != nil {
+		return fmt.Errorf("failed to get ip interface: %w", err)
+	}
+
+	ipif.NLMTU = uint32(t.MTU)
+	if foundDefault4 {
+		ipif.UseAutomaticMetric = false
+		ipif.Metric = 0
+	}
+
+	if err := ipif.Set(); err != nil {
+		return fmt.Errorf("failed to set ip interface: %w", err)
+	}
+	return nil
+}
+
+func (t *winTun) removeRoutes(routes []Route) error {
+	luid := winipcfg.LUID(t.tun.LUID())
+
+	for _, r := range routes {
+		if !r.Install {
+			continue
+		}
+
+		err := luid.DeleteRoute(r.Cidr, r.Via)
+		if err != nil {
+			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+		} else {
+			t.l.WithField("route", r).Info("Removed route")
+		}
+	}
+	return nil
+}
+
+func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
+	return r
+}
+
+func (t *winTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
+
+func (t *winTun) Name() string {
+	return t.Device
+}
+
+func (t *winTun) Read(b []byte) (int, error) {
+	return t.tun.Read(b, 0)
+}
+
+func (t *winTun) Write(b []byte) (int, error) {
+	return t.tun.Write(b, 0)
+}
+
+func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
+}
+
+func (t *winTun) Close() error {
+	// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
+	// so to be certain, just remove everything before destroying.
+	luid := winipcfg.LUID(t.tun.LUID())
+	_ = luid.FlushRoutes(windows.AF_INET)
+	_ = luid.FlushIPAddresses(windows.AF_INET)
+
+	_ = luid.FlushRoutes(windows.AF_INET6)
+	_ = luid.FlushIPAddresses(windows.AF_INET6)
+
+	_ = luid.FlushDNS(windows.AF_INET)
+	_ = luid.FlushDNS(windows.AF_INET6)
+
+	return t.tun.Close()
+}
+
+func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
+	// GUID is 128 bit
+	hash := crypto.MD5.New()
+
+	_, err := hash.Write([]byte(tunGUIDLabel))
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = hash.Write([]byte(name))
+	if err != nil {
+		return nil, err
+	}
+
+	sum := hash.Sum(nil)
+
+	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 }
 
 func checkWinTunExists() error {

+ 0 - 252
overlay/tun_wintun_windows.go

@@ -1,252 +0,0 @@
-package overlay
-
-import (
-	"crypto"
-	"fmt"
-	"io"
-	"net/netip"
-	"sync/atomic"
-	"unsafe"
-
-	"github.com/gaissmai/bart"
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
-	"github.com/slackhq/nebula/wintun"
-	"golang.org/x/sys/windows"
-	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
-)
-
-const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
-
-type winTun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
-
-	tun *wintun.NativeTun
-}
-
-func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
-	// GUID is 128 bit
-	hash := crypto.MD5.New()
-
-	_, err := hash.Write([]byte(tunGUIDLabel))
-	if err != nil {
-		return nil, err
-	}
-
-	_, err = hash.Write([]byte(name))
-	if err != nil {
-		return nil, err
-	}
-
-	sum := hash.Sum(nil)
-
-	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
-}
-
-func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
-	deviceName := c.GetString("tun.dev", "")
-	guid, err := generateGUIDByDeviceName(deviceName)
-	if err != nil {
-		return nil, fmt.Errorf("generate GUID failed: %w", err)
-	}
-
-	t := &winTun{
-		Device: deviceName,
-		cidr:   cidr,
-		MTU:    c.GetInt("tun.mtu", DefaultMTU),
-		l:      l,
-	}
-
-	err = t.reload(c, true)
-	if err != nil {
-		return nil, err
-	}
-
-	var tunDevice wintun.Device
-	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
-	if err != nil {
-		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
-		// Trying a second time resolves the issue.
-		l.WithError(err).Debug("Failed to create wintun device, retrying")
-		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
-		if err != nil {
-			return nil, fmt.Errorf("create TUN device failed: %w", err)
-		}
-	}
-	t.tun = tunDevice.(*wintun.NativeTun)
-
-	c.RegisterReloadCallback(func(c *config.C) {
-		err := t.reload(c, false)
-		if err != nil {
-			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
-		}
-	})
-
-	return t, nil
-}
-
-func (t *winTun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
-	if err != nil {
-		return err
-	}
-
-	if !initial && !change {
-		return nil
-	}
-
-	routeTree, err := makeRouteTree(t.l, routes, false)
-	if err != nil {
-		return err
-	}
-
-	// Teach nebula how to handle the routes before establishing them in the system table
-	oldRoutes := t.Routes.Swap(&routes)
-	t.routeTree.Store(routeTree)
-
-	if !initial {
-		// Remove first, if the system removes a wanted route hopefully it will be re-added next
-		err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
-		if err != nil {
-			util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
-		}
-
-		// Ensure any routes we actually want are installed
-		err = t.addRoutes(true)
-		if err != nil {
-			// Catch any stray logs
-			util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
-		}
-	}
-
-	return nil
-}
-
-func (t *winTun) Activate() error {
-	luid := winipcfg.LUID(t.tun.LUID())
-
-	err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
-	if err != nil {
-		return fmt.Errorf("failed to set address: %w", err)
-	}
-
-	err = t.addRoutes(false)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (t *winTun) addRoutes(logErrors bool) error {
-	luid := winipcfg.LUID(t.tun.LUID())
-	routes := *t.Routes.Load()
-	foundDefault4 := false
-
-	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
-			// We don't allow route MTUs so only install routes with a via
-			continue
-		}
-
-		// Add our unsafe route
-		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
-		if err != nil {
-			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
-			if logErrors {
-				retErr.Log(t.l)
-				continue
-			} else {
-				return retErr
-			}
-		} else {
-			t.l.WithField("route", r).Info("Added route")
-		}
-
-		if !foundDefault4 {
-			if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
-				foundDefault4 = true
-			}
-		}
-	}
-
-	ipif, err := luid.IPInterface(windows.AF_INET)
-	if err != nil {
-		return fmt.Errorf("failed to get ip interface: %w", err)
-	}
-
-	ipif.NLMTU = uint32(t.MTU)
-	if foundDefault4 {
-		ipif.UseAutomaticMetric = false
-		ipif.Metric = 0
-	}
-
-	if err := ipif.Set(); err != nil {
-		return fmt.Errorf("failed to set ip interface: %w", err)
-	}
-	return nil
-}
-
-func (t *winTun) removeRoutes(routes []Route) error {
-	luid := winipcfg.LUID(t.tun.LUID())
-
-	for _, r := range routes {
-		if !r.Install {
-			continue
-		}
-
-		err := luid.DeleteRoute(r.Cidr, r.Via)
-		if err != nil {
-			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
-		} else {
-			t.l.WithField("route", r).Info("Removed route")
-		}
-	}
-	return nil
-}
-
-func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
-	r, _ := t.routeTree.Load().Lookup(ip)
-	return r
-}
-
-func (t *winTun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
-func (t *winTun) Name() string {
-	return t.Device
-}
-
-func (t *winTun) Read(b []byte) (int, error) {
-	return t.tun.Read(b, 0)
-}
-
-func (t *winTun) Write(b []byte) (int, error) {
-	return t.tun.Write(b, 0)
-}
-
-func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
-}
-
-func (t *winTun) Close() error {
-	// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
-	// so to be certain, just remove everything before destroying.
-	luid := winipcfg.LUID(t.tun.LUID())
-	_ = luid.FlushRoutes(windows.AF_INET)
-	_ = luid.FlushIPAddresses(windows.AF_INET)
-	/* We don't support IPV6 yet
-	_ = luid.FlushRoutes(windows.AF_INET6)
-	_ = luid.FlushIPAddresses(windows.AF_INET6)
-	*/
-	_ = luid.FlushDNS(windows.AF_INET)
-
-	return t.tun.Close()
-}

+ 6 - 6
overlay/user.go

@@ -8,16 +8,16 @@ import (
 	"github.com/slackhq/nebula/config"
 )
 
-func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
-	return NewUserDevice(tunCidr)
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+	return NewUserDevice(vpnNetworks)
 }
 
-func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
+func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
 	// these pipes guarantee each write/read will match 1:1
 	or, ow := io.Pipe()
 	ir, iw := io.Pipe()
 	return &UserDevice{
-		tunCidr:        tunCidr,
+		vpnNetworks:    vpnNetworks,
 		outboundReader: or,
 		outboundWriter: ow,
 		inboundReader:  ir,
@@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
 }
 
 type UserDevice struct {
-	tunCidr netip.Prefix
+	vpnNetworks []netip.Prefix
 
 	outboundReader *io.PipeReader
 	outboundWriter *io.PipeWriter
@@ -38,7 +38,7 @@ type UserDevice struct {
 func (d *UserDevice) Activate() error {
 	return nil
 }
-func (d *UserDevice) Cidr() netip.Prefix                { return d.tunCidr }
+func (d *UserDevice) Networks() []netip.Prefix          { return d.vpnNetworks }
 func (d *UserDevice) Name() string                      { return "faketun0" }
 func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {

+ 328 - 68
pki.go

@@ -1,13 +1,19 @@
 package nebula
 
 import (
+	"encoding/binary"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"net"
+	"net/netip"
 	"os"
+	"slices"
 	"strings"
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
@@ -21,12 +27,22 @@ type PKI struct {
 }
 
 type CertState struct {
-	Certificate         cert.Certificate
-	RawCertificate      []byte
-	RawCertificateNoKey []byte
-	PublicKey           []byte
-	PrivateKey          []byte
-	pkcs11Backed        bool
+	v1Cert           cert.Certificate
+	v1HandshakeBytes []byte
+
+	v2Cert           cert.Certificate
+	v2HandshakeBytes []byte
+
+	defaultVersion cert.Version
+	privateKey     []byte
+	pkcs11Backed   bool
+	cipher         string
+
+	myVpnNetworks            []netip.Prefix
+	myVpnNetworksTable       *bart.Table[struct{}]
+	myVpnAddrs               []netip.Addr
+	myVpnAddrsTable          *bart.Table[struct{}]
+	myVpnBroadcastAddrsTable *bart.Table[struct{}]
 }
 
 func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -46,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
 	return pki, nil
 }
 
-func (p *PKI) GetCertState() *CertState {
-	return p.cs.Load()
-}
-
 func (p *PKI) GetCAPool() *cert.CAPool {
 	return p.caPool.Load()
 }
 
+func (p *PKI) getCertState() *CertState {
+	return p.cs.Load()
+}
+
 func (p *PKI) reload(c *config.C, initial bool) error {
-	err := p.reloadCert(c, initial)
+	err := p.reloadCerts(c, initial)
 	if err != nil {
 		if initial {
 			return err
@@ -74,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
-	cs, err := newCertStateFromConfig(c)
+func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
+	newState, err := newCertStateFromConfig(c)
 	if err != nil {
 		return util.NewContextualError("Could not load client cert", nil, err)
 	}
 
 	if !initial {
-		//TODO: include check for mask equality as well
+		currentState := p.cs.Load()
+		if newState.v1Cert != nil {
+			if currentState.v1Cert == nil {
+				return util.NewContextualError("v1 certificate was added, restart required", nil, err)
+			}
+
+			// did IP in cert change? if so, don't set
+			if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
+				return util.NewContextualError(
+					"Networks in new cert was different from old",
+					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
+					nil,
+				)
+			}
+
+			if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
+				return util.NewContextualError(
+					"Curve in new cert was different from old",
+					m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
+					nil,
+				)
+			}
+
+		} else if currentState.v1Cert != nil {
+			//TODO: CERT-V2 we should be able to tear this down
+			return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
+		}
+
+		if newState.v2Cert != nil {
+			if currentState.v2Cert == nil {
+				return util.NewContextualError("v2 certificate was added, restart required", nil, err)
+			}
+
+			// did IP in cert change? if so, don't set
+			if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
+				return util.NewContextualError(
+					"Networks in new cert was different from old",
+					m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
+					nil,
+				)
+			}
 
-		// did IP in cert change? if so, don't set
-		currentCert := p.cs.Load().Certificate
-		oldIPs := currentCert.Networks()
-		newIPs := cs.Certificate.Networks()
-		if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
+			if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+				return util.NewContextualError(
+					"Curve in new cert was different from old",
+					m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+					nil,
+				)
+			}
+
+		} else if currentState.v2Cert != nil {
+			return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
+		}
+
+		// Cipher cant be hot swapped so just leave it at what it was before
+		newState.cipher = currentState.cipher
+
+	} else {
+		newState.cipher = c.GetString("cipher", "aes")
+		//TODO: this sucks and we should make it not a global
+		switch newState.cipher {
+		case "aes":
+			noiseEndianness = binary.BigEndian
+		case "chachapoly":
+			noiseEndianness = binary.LittleEndian
+		default:
 			return util.NewContextualError(
-				"Networks in new cert was different from old",
-				m{"new_network": newIPs[0], "old_network": oldIPs[0]},
+				"unknown cipher",
+				m{"cipher": newState.cipher},
 				nil,
 			)
 		}
 	}
 
-	p.cs.Store(cs)
+	p.cs.Store(newState)
+
+	//TODO: CERT-V2 newState needs a stringer that does json
 	if initial {
-		p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
+		p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
 	} else {
-		p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
+		p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
 	}
 	return nil
 }
@@ -116,55 +193,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
 	return nil
 }
 
-func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
-	// Marshal the certificate to ensure it is valid
-	rawCertificate, err := certificate.Marshal()
-	if err != nil {
-		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
+func (cs *CertState) GetDefaultCertificate() cert.Certificate {
+	c := cs.getCertificate(cs.defaultVersion)
+	if c == nil {
+		panic("No default certificate found")
 	}
+	return c
+}
 
-	publicKey := certificate.PublicKey()
-	cs := &CertState{
-		RawCertificate: rawCertificate,
-		Certificate:    certificate,
-		PrivateKey:     privateKey,
-		PublicKey:      publicKey,
-		pkcs11Backed:   pkcs11backed,
+func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
+	switch v {
+	case cert.Version1:
+		return cs.v1Cert
+	case cert.Version2:
+		return cs.v2Cert
 	}
 
-	rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
-	if err != nil {
-		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
+	return nil
+}
+
+// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
+// Callers must check if the return []byte is nil.
+func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
+	switch v {
+	case cert.Version1:
+		return cs.v1HandshakeBytes
+	case cert.Version2:
+		return cs.v2HandshakeBytes
+	default:
+		return nil
 	}
-	cs.RawCertificateNoKey = rawCertNoKey
+}
 
-	return cs, nil
+func (cs *CertState) String() string {
+	b, err := cs.MarshalJSON()
+	if err != nil {
+		return fmt.Sprintf("error marshaling certificate state: %v", err)
+	}
+	return string(b)
 }
 
-func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
-	var pemPrivateKey []byte
-	if strings.Contains(privPathOrPEM, "-----BEGIN") {
-		pemPrivateKey = []byte(privPathOrPEM)
-		privPathOrPEM = "<inline>"
-		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+func (cs *CertState) MarshalJSON() ([]byte, error) {
+	msg := []json.RawMessage{}
+	if cs.v1Cert != nil {
+		b, err := cs.v1Cert.MarshalJSON()
 		if err != nil {
-			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+			return nil, err
 		}
-	} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
-		rawKey = []byte(privPathOrPEM)
-		return rawKey, cert.Curve_P256, true, nil
-	} else {
-		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
+		msg = append(msg, b)
+	}
+
+	if cs.v2Cert != nil {
+		b, err := cs.v2Cert.MarshalJSON()
 		if err != nil {
-			return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
-		}
-		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
-		if err != nil {
-			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+			return nil, err
 		}
+		msg = append(msg, b)
 	}
 
-	return
+	return json.Marshal(msg)
 }
 
 func newCertStateFromConfig(c *config.C) (*CertState, error) {
@@ -198,24 +285,197 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
 		}
 	}
 
-	nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
+	var crt, v1, v2 cert.Certificate
+	for {
+		// Load the certificate
+		crt, rawCert, err = loadCertificate(rawCert)
+		if err != nil {
+			return nil, err
+		}
+
+		switch crt.Version() {
+		case cert.Version1:
+			if v1 != nil {
+				return nil, fmt.Errorf("v1 certificate already found in pki.cert")
+			}
+			v1 = crt
+		case cert.Version2:
+			if v2 != nil {
+				return nil, fmt.Errorf("v2 certificate already found in pki.cert")
+			}
+			v2 = crt
+		default:
+			return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
+		}
+
+		if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
+			break
+		}
+	}
+
+	if v1 == nil && v2 == nil {
+		return nil, errors.New("no certificates found in pki.cert")
+	}
+
+	useDefaultVersion := uint32(1)
+	if v1 == nil {
+		// The only condition that requires v2 as the default is if only a v2 certificate is present
+		// We do this to avoid having to configure it specifically in the config file
+		useDefaultVersion = 2
+	}
+
+	rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
+	var defaultVersion cert.Version
+	switch rawDefaultVersion {
+	case 1:
+		if v1 == nil {
+			return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
+		}
+		defaultVersion = cert.Version1
+	case 2:
+		defaultVersion = cert.Version2
+	default:
+		return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
+	}
+
+	return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
+}
+
+func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
+	cs := CertState{
+		privateKey:               privateKey,
+		pkcs11Backed:             pkcs11backed,
+		myVpnNetworksTable:       new(bart.Table[struct{}]),
+		myVpnAddrsTable:          new(bart.Table[struct{}]),
+		myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
+	}
+
+	if v1 != nil && v2 != nil {
+		if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
+			return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
+		}
+
+		if v1.Curve() != v2.Curve() {
+			return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
+		}
+
+		//TODO: CERT-V2 make sure v2 has v1s address
+
+		cs.defaultVersion = dv
+	}
+
+	if v1 != nil {
+		if pkcs11backed {
+			//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
+		} else {
+			if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
+				return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+			}
+		}
+
+		v1hs, err := v1.MarshalForHandshakes()
+		if err != nil {
+			return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
+		}
+		cs.v1Cert = v1
+		cs.v1HandshakeBytes = v1hs
+
+		if cs.defaultVersion == 0 {
+			cs.defaultVersion = cert.Version1
+		}
+	}
+
+	if v2 != nil {
+		if pkcs11backed {
+			//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
+		} else {
+			if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
+				return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+			}
+		}
+
+		v2hs, err := v2.MarshalForHandshakes()
+		if err != nil {
+			return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
+		}
+		cs.v2Cert = v2
+		cs.v2HandshakeBytes = v2hs
+
+		if cs.defaultVersion == 0 {
+			cs.defaultVersion = cert.Version2
+		}
+	}
+
+	var crt cert.Certificate
+	crt = cs.getCertificate(cert.Version2)
+	if crt == nil {
+		// v2 certificates are a superset, only look at v1 if its all we have
+		crt = cs.getCertificate(cert.Version1)
+	}
+
+	for _, network := range crt.Networks() {
+		cs.myVpnNetworks = append(cs.myVpnNetworks, network)
+		cs.myVpnNetworksTable.Insert(network, struct{}{})
+
+		cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
+		cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
+
+		if network.Addr().Is4() {
+			addr := network.Masked().Addr().As4()
+			mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
+			binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
+			cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
+		}
+	}
+
+	return &cs, nil
+}
+
+func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
+	var pemPrivateKey []byte
+	if strings.Contains(privPathOrPEM, "-----BEGIN") {
+		pemPrivateKey = []byte(privPathOrPEM)
+		privPathOrPEM = "<inline>"
+		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+		}
+	} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
+		rawKey = []byte(privPathOrPEM)
+		return rawKey, cert.Curve_P256, true, nil
+	} else {
+		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
+		}
+		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+		}
+	}
+
+	return
+}
+
+func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
+	c, b, err := cert.UnmarshalCertificateFromPEM(b)
 	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
+		return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
 	}
 
-	if nebulaCert.Expired(time.Now()) {
-		return nil, fmt.Errorf("nebula certificate for this host is expired")
+	if c.Expired(time.Now()) {
+		return nil, b, fmt.Errorf("nebula certificate for this host is expired")
 	}
 
-	if len(nebulaCert.Networks()) == 0 {
-		return nil, fmt.Errorf("no networks encoded in certificate")
+	if len(c.Networks()) == 0 {
+		return nil, b, fmt.Errorf("no networks encoded in certificate")
 	}
 
-	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
-		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+	if c.IsCA() {
+		return nil, b, fmt.Errorf("host certificate is a CA certificate")
 	}
 
-	return newCertState(nebulaCert, isPkcs11, rawKey)
+	return c, b, nil
 }
 
 func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {

+ 166 - 130
relay_manager.go

@@ -9,6 +9,7 @@ import (
 	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )
@@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
 				Type:       relayType,
 				State:      state,
 				LocalIndex: index,
-				PeerIp:     vpnIp,
+				PeerAddr:   vpnIp,
 			}
 
 			if remoteIdx != nil {
@@ -91,40 +92,71 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
 func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
 	relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
 	if !ok {
-		rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp,
+		fields := logrus.Fields{
+			"relay":               relayHostInfo.vpnAddrs[0],
 			"initiatorRelayIndex": m.InitiatorRelayIndex,
-			"relayFrom":           m.RelayFromIp,
-			"relayTo":             m.RelayToIp}).Info("relayManager failed to update relay")
+		}
+
+		if m.RelayFromAddr == nil {
+			fields["relayFrom"] = m.OldRelayFromAddr
+		} else {
+			fields["relayFrom"] = m.RelayFromAddr
+		}
+
+		if m.RelayToAddr == nil {
+			fields["relayTo"] = m.OldRelayToAddr
+		} else {
+			fields["relayTo"] = m.RelayToAddr
+		}
+
+		rm.l.WithFields(fields).Info("relayManager failed to update relay")
 		return nil, fmt.Errorf("unknown relay")
 	}
 
 	return relay, nil
 }
 
-func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) {
+func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
+	msg := &NebulaControl{}
+	err := msg.Unmarshal(d)
+	if err != nil {
+		h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
+		return
+	}
+
+	var v cert.Version
+	if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 {
+		v = cert.Version1
 
-	switch m.Type {
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr)
+		msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
+
+		binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr)
+		msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
+	} else {
+		v = cert.Version2
+	}
+
+	switch msg.Type {
 	case NebulaControl_CreateRelayRequest:
-		rm.handleCreateRelayRequest(h, f, m)
+		rm.handleCreateRelayRequest(v, h, f, msg)
 	case NebulaControl_CreateRelayResponse:
-		rm.handleCreateRelayResponse(h, f, m)
+		rm.handleCreateRelayResponse(v, h, f, msg)
 	}
-
 }
 
-func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
+func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
 	rm.l.WithFields(logrus.Fields{
-		"relayFrom":           m.RelayFromIp,
-		"relayTo":             m.RelayToIp,
+		"relayFrom":           protoAddrToNetAddr(m.RelayFromAddr),
+		"relayTo":             protoAddrToNetAddr(m.RelayToAddr),
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
-		"vpnIp":               h.vpnIp}).
+		"vpnAddrs":            h.vpnAddrs}).
 		Info("handleCreateRelayResponse")
-	target := m.RelayToIp
-	//TODO: IPV6-WORK
-	b := [4]byte{}
-	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
-	targetAddr := netip.AddrFrom4(b)
+
+	target := m.RelayToAddr
+	targetAddr := protoAddrToNetAddr(target)
 
 	relay, err := rm.EstablishRelay(h, m)
 	if err != nil {
@@ -136,68 +168,88 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		return
 	}
 	// I'm the middle man. Let the initiator know that the I've established the relay they requested.
-	peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
+	peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
 	if peerHostInfo == nil {
-		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
+		rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
 		return
 	}
 	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
 	if !ok {
-		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
+		rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
 		return
 	}
-	if peerRelay.State == PeerRequested {
-		//TODO: IPV6-WORK
-		b = peerHostInfo.vpnIp.As4()
-		peerRelay.State = Established
+	switch peerRelay.State {
+	case Requested:
+		// I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer
+		// to respond to complete the connection.
+	case PeerRequested, Disestablished, Established:
+		peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established)
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: peerRelay.LocalIndex,
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
-			RelayFromIp:         binary.BigEndian.Uint32(b[:]),
-			RelayToIp:           uint32(target),
 		}
+
+		if v == cert.Version1 {
+			peer := peerHostInfo.vpnAddrs[0]
+			if !peer.Is4() {
+				rm.l.WithField("relayFrom", peer).
+					WithField("relayTo", target).
+					WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
+					WithField("responderRelayIndex", resp.ResponderRelayIndex).
+					WithField("vpnAddrs", peerHostInfo.vpnAddrs).
+					Error("Refusing to CreateRelayResponse for a v1 relay with an ipv6 address")
+				return
+			}
+
+			b := peer.As4()
+			resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = targetAddr.As4()
+			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		} else {
+			resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
+			resp.RelayToAddr = target
+		}
+
 		msg, err := resp.Marshal()
 		if err != nil {
-			rm.l.
-				WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
+			rm.l.WithError(err).
+				Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           resp.RelayFromIp,
-				"relayTo":             resp.RelayToIp,
+				"relayFrom":           resp.RelayFromAddr,
+				"relayTo":             resp.RelayToAddr,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
-				"vpnIp":               peerHostInfo.vpnIp}).
+				"vpnAddrs":            peerHostInfo.vpnAddrs}).
 				Info("send CreateRelayResponse")
 		}
 	}
 }
 
-func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
-	//TODO: IPV6-WORK
-	b := [4]byte{}
-	binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
-	from := netip.AddrFrom4(b)
-
-	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
-	target := netip.AddrFrom4(b)
+func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
+	from := protoAddrToNetAddr(m.RelayFromAddr)
+	target := protoAddrToNetAddr(m.RelayToAddr)
 
 	logMsg := rm.l.WithFields(logrus.Fields{
 		"relayFrom":           from,
 		"relayTo":             target,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
-		"vpnIp":               h.vpnIp})
+		"vpnAddrs":            h.vpnAddrs})
 
 	logMsg.Info("handleCreateRelayRequest")
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// an issue migrating relays over to newly re-handshaked host info objects.
-	if from == f.myVpnNet.Addr() {
+	_, found := f.myVpnAddrsTable.Lookup(from)
+	if found {
 		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		return
 	}
+
 	// Is the target of the relay me?
-	if target == f.myVpnNet.Addr() {
+	_, found = f.myVpnAddrsTable.Lookup(target)
+	if found {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		if ok {
 			switch existingRelay.State {
@@ -215,6 +267,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 						"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
 					return
 				}
+			case Disestablished:
+				if existingRelay.RemoteIndex != m.InitiatorRelayIndex {
+					// We got a brand new Relay request, because its index is different than what we saw before.
+					// This should never happen. The peer should never change an index, once created.
+					logMsg.WithFields(logrus.Fields{
+						"existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
+					return
+				}
+				// Mark the relay as 'Established' because it's safe to use again
+				h.relayState.UpdateRelayForByIpState(from, Established)
+			case PeerRequested:
+				// I should never be in this state, because I am terminal, not forwarding.
+				logMsg.WithFields(logrus.Fields{
+					"existingRemoteIndex": existingRelay.RemoteIndex,
+					"state":               existingRelay.State}).Error("Unexpected Relay State found")
 			}
 		} else {
 			_, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established)
@@ -226,21 +293,26 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 
 		relay, ok := h.relayState.QueryRelayForByIp(from)
 		if !ok {
-			logMsg.Error("Relay State not found")
+			logMsg.WithField("from", from).Error("Relay State not found")
 			return
 		}
 
-		//TODO: IPV6-WORK
-		fromB := from.As4()
-		targetB := target.As4()
-
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: relay.LocalIndex,
 			InitiatorRelayIndex: relay.RemoteIndex,
-			RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
-			RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 		}
+
+		if v == cert.Version1 {
+			b := from.As4()
+			resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = target.As4()
+			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		} else {
+			resp.RelayFromAddr = netAddrToProtoAddr(from)
+			resp.RelayToAddr = netAddrToProtoAddr(target)
+		}
+
 		msg, err := resp.Marshal()
 		if err != nil {
 			logMsg.
@@ -248,12 +320,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
-				//TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
 				"relayFrom":           from,
 				"relayTo":             target,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
-				"vpnIp":               h.vpnIp}).
+				"vpnAddrs":            h.vpnAddrs}).
 				Info("send CreateRelayResponse")
 		}
 		return
@@ -262,7 +333,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		if !rm.GetAmRelay() {
 			return
 		}
-		peer := rm.hostmap.QueryVpnIp(target)
+		peer := rm.hostmap.QueryVpnAddr(target)
 		if peer == nil {
 			// Try to establish a connection to this host. If we get a future relay request,
 			// we'll be ready!
@@ -273,104 +344,69 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			// Only create relays to peers for whom I have a direct connection
 			return
 		}
-		sendCreateRequest := false
 		var index uint32
 		var err error
 		targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
 		if ok {
 			index = targetRelay.LocalIndex
-			if targetRelay.State == Requested {
-				sendCreateRequest = true
-			}
 		} else {
 			// Allocate an index in the hostMap for this relay peer
 			index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested)
 			if err != nil {
 				return
 			}
-			sendCreateRequest = true
 		}
-		if sendCreateRequest {
-			//TODO: IPV6-WORK
-			fromB := h.vpnIp.As4()
-			targetB := target.As4()
-
-			// Send a CreateRelayRequest to the peer.
-			req := NebulaControl{
-				Type:                NebulaControl_CreateRelayRequest,
-				InitiatorRelayIndex: index,
-				RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
-				RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
-			}
-			msg, err := req.Marshal()
-			if err != nil {
-				logMsg.
-					WithError(err).Error("relayManager Failed to marshal Control message to create relay")
-			} else {
-				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
-				rm.l.WithFields(logrus.Fields{
-					//TODO: IPV6-WORK another lazy used to use the req object
-					"relayFrom":           h.vpnIp,
-					"relayTo":             target,
-					"initiatorRelayIndex": req.InitiatorRelayIndex,
-					"responderRelayIndex": req.ResponderRelayIndex,
-					"vpnIp":               target}).
-					Info("send CreateRelayRequest")
+		peer.relayState.UpdateRelayForByIpState(from, Requested)
+		// Send a CreateRelayRequest to the peer.
+		req := NebulaControl{
+			Type:                NebulaControl_CreateRelayRequest,
+			InitiatorRelayIndex: index,
+		}
+
+		if v == cert.Version1 {
+			if !h.vpnAddrs[0].Is4() {
+				rm.l.WithField("relayFrom", h.vpnAddrs[0]).
+					WithField("relayTo", target).
+					WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
+					WithField("responderRelayIndex", req.ResponderRelayIndex).
+					WithField("vpnAddr", target).
+					Error("Refusing to CreateRelayRequest for a v1 relay with an ipv6 address")
+				return
 			}
+
+			b := h.vpnAddrs[0].As4()
+			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = target.As4()
+			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		} else {
+			req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
+			req.RelayToAddr = netAddrToProtoAddr(target)
 		}
+
+		msg, err := req.Marshal()
+		if err != nil {
+			logMsg.
+				WithError(err).Error("relayManager Failed to marshal Control message to create relay")
+		} else {
+			f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
+			rm.l.WithFields(logrus.Fields{
+				"relayFrom":           h.vpnAddrs[0],
+				"relayTo":             target,
+				"initiatorRelayIndex": req.InitiatorRelayIndex,
+				"responderRelayIndex": req.ResponderRelayIndex,
+				"vpnAddr":             target}).
+				Info("send CreateRelayRequest")
+		}
+
 		// Also track the half-created Relay state just received
-		relay, ok := h.relayState.QueryRelayForByIp(target)
+		_, ok = h.relayState.QueryRelayForByIp(target)
 		if !ok {
-			// Add the relay
-			state := PeerRequested
-			if targetRelay != nil && targetRelay.State == Established {
-				state = Established
-			}
-			_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state)
+			_, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested)
 			if err != nil {
 				logMsg.
 					WithError(err).Error("relayManager Failed to allocate a local index for relay")
 				return
 			}
-		} else {
-			switch relay.State {
-			case Established:
-				if relay.RemoteIndex != m.InitiatorRelayIndex {
-					// We got a brand new Relay request, because its index is different than what we saw before.
-					// This should never happen. The peer should never change an index, once created.
-					logMsg.WithFields(logrus.Fields{
-						"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
-					return
-				}
-				//TODO: IPV6-WORK
-				fromB := h.vpnIp.As4()
-				targetB := target.As4()
-				resp := NebulaControl{
-					Type:                NebulaControl_CreateRelayResponse,
-					ResponderRelayIndex: relay.LocalIndex,
-					InitiatorRelayIndex: relay.RemoteIndex,
-					RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
-					RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
-				}
-				msg, err := resp.Marshal()
-				if err != nil {
-					rm.l.
-						WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
-				} else {
-					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
-					rm.l.WithFields(logrus.Fields{
-						//TODO: IPV6-WORK more lazy, used to use resp object
-						"relayFrom":           h.vpnIp,
-						"relayTo":             target,
-						"initiatorRelayIndex": resp.InitiatorRelayIndex,
-						"responderRelayIndex": resp.ResponderRelayIndex,
-						"vpnIp":               h.vpnIp}).
-						Info("send CreateRelayResponse")
-				}
-
-			case Requested:
-				// Keep waiting for the other relay to complete
-			}
 		}
 	}
 }

+ 51 - 35
remote_list.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"net/netip"
+	"slices"
 	"sort"
 	"strconv"
 	"sync"
@@ -17,8 +18,8 @@ import (
 type forEachFunc func(addr netip.AddrPort, preferred bool)
 
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool
+type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool
 
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // The string key is the owners vpnIp
@@ -32,9 +33,6 @@ type Cache struct {
 	Relay    []netip.Addr     `json:"relay"`
 }
 
-//TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
-// We will never clean learned/reported information for them as it stands today
-
 // cache is an internal struct that splits v4 and v6 addresses inside the cache map
 type cache struct {
 	v4    *cacheV4
@@ -48,14 +46,14 @@ type cacheRelay struct {
 
 // cacheV4 stores learned and reported ipv4 records under cache
 type cacheV4 struct {
-	learned  *Ip4AndPort
-	reported []*Ip4AndPort
+	learned  *V4AddrPort
+	reported []*V4AddrPort
 }
 
 // cacheV4 stores learned and reported ipv6 records under cache
 type cacheV6 struct {
-	learned  *Ip6AndPort
-	reported []*Ip6AndPort
+	learned  *V6AddrPort
+	reported []*V6AddrPort
 }
 
 type hostnamePort struct {
@@ -170,7 +168,7 @@ func (hr *hostnamesResults) Cancel() {
 	}
 }
 
-func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
+func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
 	var retSlice []netip.AddrPort
 	if hr != nil {
 		p := hr.ips.Load()
@@ -189,6 +187,9 @@ type RemoteList struct {
 	// Every interaction with internals requires a lock!
 	sync.RWMutex
 
+	// The full list of vpn addresses assigned to this host
+	vpnAddrs []netip.Addr
+
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
 	addrs []netip.AddrPort
 
@@ -212,13 +213,16 @@ type RemoteList struct {
 }
 
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
-	return &RemoteList{
+func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
+	r := &RemoteList{
+		vpnAddrs:  make([]netip.Addr, len(vpnAddrs)),
 		addrs:     make([]netip.AddrPort, 0),
 		relays:    make([]netip.Addr, 0),
 		cache:     make(map[netip.Addr]*cache),
 		shouldAdd: shouldAdd,
 	}
+	copy(r.vpnAddrs, vpnAddrs)
+	return r
 }
 
 func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
@@ -268,14 +272,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort
 // LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
-// TODO: this needs to support the allow list list
 func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
 	r.Lock()
 	defer r.Unlock()
 	if remote.Addr().Is4() {
-		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
+		r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port()))
 	} else {
-		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
+		r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port()))
 	}
 }
 
@@ -304,21 +307,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
 
 		if mc.v4 != nil {
 			if mc.v4.learned != nil {
-				c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
+				c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned))
 			}
 
 			for _, a := range mc.v4.reported {
-				c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
+				c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a))
 			}
 		}
 
 		if mc.v6 != nil {
 			if mc.v6.learned != nil {
-				c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
+				c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned))
 			}
 
 			for _, a := range mc.v6.reported {
-				c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
+				c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a))
 			}
 		}
 
@@ -379,7 +382,6 @@ func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
 	defer r.Unlock()
 
 	// Only rebuild if the cache changed
-	//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
 	if r.shouldRebuild {
 		r.unlockedCollect()
 		r.shouldRebuild = false
@@ -401,14 +403,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
 
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 }
 
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
@@ -423,7 +425,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPor
 	}
 }
 
-func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
+func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeRelay(ownerVpnIp)
 
@@ -436,12 +438,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A
 
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 	// We are doing the easy append because this is rarely called
-	c.reported = append([]*Ip4AndPort{to}, c.reported...)
+	c.reported = append([]*V4AddrPort{to}, c.reported...)
 	if len(c.reported) > MaxRemotes {
 		c.reported = c.reported[:MaxRemotes]
 	}
@@ -449,14 +451,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
 
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 }
 
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
@@ -473,12 +475,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor
 
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 	// We are doing the easy append because this is rarely called
-	c.reported = append([]*Ip6AndPort{to}, c.reported...)
+	c.reported = append([]*V6AddrPort{to}, c.reported...)
 	if len(c.reported) > MaxRemotes {
 		c.reported = c.reported[:MaxRemotes]
 	}
@@ -536,14 +538,14 @@ func (r *RemoteList) unlockedCollect() {
 	for _, c := range r.cache {
 		if c.v4 != nil {
 			if c.v4.learned != nil {
-				u := AddrPortFromIp4AndPort(c.v4.learned)
+				u := protoV4AddrPortToNetAddrPort(c.v4.learned)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
 			}
 
 			for _, v := range c.v4.reported {
-				u := AddrPortFromIp4AndPort(v)
+				u := protoV4AddrPortToNetAddrPort(v)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
@@ -552,14 +554,14 @@ func (r *RemoteList) unlockedCollect() {
 
 		if c.v6 != nil {
 			if c.v6.learned != nil {
-				u := AddrPortFromIp6AndPort(c.v6.learned)
+				u := protoV6AddrPortToNetAddrPort(c.v6.learned)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
 			}
 
 			for _, v := range c.v6.reported {
-				u := AddrPortFromIp6AndPort(v)
+				u := protoV6AddrPortToNetAddrPort(v)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
@@ -573,7 +575,7 @@ func (r *RemoteList) unlockedCollect() {
 		}
 	}
 
-	dnsAddrs := r.hr.GetIPs()
+	dnsAddrs := r.hr.GetAddrs()
 	for _, addr := range dnsAddrs {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
 			if !r.unlockedIsBad(addr) {
@@ -589,6 +591,21 @@ func (r *RemoteList) unlockedCollect() {
 
 // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
 func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
+	// Use a map to deduplicate any relay addresses
+	dedupedRelays := map[netip.Addr]struct{}{}
+	for _, relay := range r.relays {
+		dedupedRelays[relay] = struct{}{}
+	}
+	r.relays = r.relays[:0]
+	for relay := range dedupedRelays {
+		r.relays = append(r.relays, relay)
+	}
+	// Put them in a somewhat consistent order after de-duplication
+	slices.SortFunc(r.relays, func(a, b netip.Addr) int {
+		return a.Compare(b)
+	})
+
+	// Now the addrs
 	n := len(r.addrs)
 	if n < 2 {
 		return
@@ -687,7 +704,6 @@ func minInt(a, b int) int {
 
 // isPreferred returns true of the ip is contained in the preferredRanges list
 func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
-	//TODO: this would be better in a CIDR6Tree
 	for _, p := range preferredRanges {
 		if p.Contains(ip) {
 			return true

+ 35 - 20
remote_list_test.go

@@ -9,11 +9,11 @@ import (
 )
 
 func TestRemoteList_Rebuild(t *testing.T) {
-	rl := NewRemoteList(nil)
+	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
 	rl.unlockedSetV4(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip4AndPort{
+		[]*V4AddrPort{
 			newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
@@ -25,20 +25,30 @@ func TestRemoteList_Rebuild(t *testing.T) {
 			newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
 			newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
 		netip.MustParseAddr("0.0.0.1"),
 		netip.MustParseAddr("0.0.0.1"),
-		[]*Ip6AndPort{
+		[]*V6AddrPort{
 			newIp6AndPortFromString("[1::1]:1"), // this is duped
 			newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 			newIp6AndPortFromString("[1::1]:2"), // this is a dupe
 		},
-		func(netip.Addr, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *V6AddrPort) bool { return true },
+	)
+
+	rl.unlockedSetRelay(
+		netip.MustParseAddr("0.0.0.1"),
+		[]netip.Addr{
+			netip.MustParseAddr("1::1"),
+			netip.MustParseAddr("1.2.3.4"),
+			netip.MustParseAddr("1.2.3.4"),
+			netip.MustParseAddr("1::1"),
+		},
 	)
 
 	rl.Rebuild([]netip.Prefix{})
@@ -76,6 +86,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
 	assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
 	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
 
+	// assert relay deduplicated
+	assert.Len(t, rl.relays, 2)
+	assert.Equal(t, "1.2.3.4", rl.relays[0].String())
+	assert.Equal(t, "1::1", rl.relays[1].String())
+
 	// Ensure we can hoist a specific ipv4 range over anything else
 	rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
@@ -98,11 +113,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
 }
 
 func BenchmarkFullRebuild(b *testing.B) {
-	rl := NewRemoteList(nil)
+	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
 	rl.unlockedSetV4(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip4AndPort{
+		[]*V4AddrPort{
 			newIp4AndPortFromString("70.199.182.92:1475"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"),
@@ -112,19 +127,19 @@ func BenchmarkFullRebuild(b *testing.B) {
 			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
 			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip6AndPort{
+		[]*V6AddrPort{
 			newIp6AndPortFromString("[1::1]:1"),
 			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
-		func(netip.Addr, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *V6AddrPort) bool { return true },
 	)
 
 	b.Run("no preferred", func(b *testing.B) {
@@ -160,11 +175,11 @@ func BenchmarkFullRebuild(b *testing.B) {
 }
 
 func BenchmarkSortRebuild(b *testing.B) {
-	rl := NewRemoteList(nil)
+	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
 	rl.unlockedSetV4(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip4AndPort{
+		[]*V4AddrPort{
 			newIp4AndPortFromString("70.199.182.92:1475"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"),
@@ -174,19 +189,19 @@ func BenchmarkSortRebuild(b *testing.B) {
 			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
 			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip6AndPort{
+		[]*V6AddrPort{
 			newIp6AndPortFromString("[1::1]:1"),
 			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
-		func(netip.Addr, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *V6AddrPort) bool { return true },
 	)
 
 	b.Run("no preferred", func(b *testing.B) {
@@ -224,19 +239,19 @@ func BenchmarkSortRebuild(b *testing.B) {
 	})
 }
 
-func newIp4AndPortFromString(s string) *Ip4AndPort {
+func newIp4AndPortFromString(s string) *V4AddrPort {
 	a := netip.MustParseAddrPort(s)
 	v4Addr := a.Addr().As4()
-	return &Ip4AndPort{
-		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+	return &V4AddrPort{
+		Addr: binary.BigEndian.Uint32(v4Addr[:]),
 		Port: uint32(a.Port()),
 	}
 }
 
-func newIp6AndPortFromString(s string) *Ip6AndPort {
+func newIp6AndPortFromString(s string) *V6AddrPort {
 	a := netip.MustParseAddrPort(s)
 	v6Addr := a.Addr().As16()
-	return &Ip6AndPort{
+	return &V6AddrPort{
 		Hi:   binary.BigEndian.Uint64(v6Addr[:8]),
 		Lo:   binary.BigEndian.Uint64(v6Addr[8:]),
 		Port: uint32(a.Port()),

+ 2 - 2
service/service.go

@@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) {
 		},
 	})
 
-	ipNet := device.Cidr()
+	ipNet := device.Networks()
 	pa := tcpip.ProtocolAddress{
-		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
+		AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
 		Protocol:          ipv4.ProtocolNumber,
 	}
 	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

+ 3 - 3
service/service_test.go

@@ -10,8 +10,8 @@ import (
 
 	"dario.cat/mergo"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/e2e"
 	"golang.org/x/sync/errgroup"
 	"gopkg.in/yaml.v2"
 )
@@ -19,7 +19,7 @@ import (
 type m map[string]interface{}
 
 func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
-	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
+	_, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
 	caB, err := caCrt.MarshalPEM()
 	if err != nil {
 		panic(err)
@@ -79,7 +79,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n
 }
 
 func TestService(t *testing.T) {
-	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
 		"static_host_map": m{},
 		"lighthouse": m{

+ 73 - 88
ssh.go

@@ -77,9 +77,6 @@ func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
 // that callers may invoke to run the configured ssh server. On
 // failure, it returns nil, error.
 func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
-	//TODO conntrack list
-	//TODO print firewall rules or hash?
-
 	listen := c.GetString("sshd.listen", "")
 	if listen == "" {
 		return nil, fmt.Errorf("sshd.listen must be provided")
@@ -93,7 +90,6 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 		return nil, fmt.Errorf("sshd.listen can not use port 22")
 	}
 
-	//TODO: no good way to reload this right now
 	hostKeyPathOrKey := c.GetString("sshd.host_key", "")
 	if hostKeyPathOrKey == "" {
 		return nil, fmt.Errorf("sshd.host_key must be provided")
@@ -320,7 +316,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-cert",
-		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip",
+		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintCertFlags{}
@@ -336,7 +332,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "print-tunnel",
-		ShortDescription: "Prints json details about a tunnel for the provided vpn ip",
+		ShortDescription: "Prints json details about a tunnel for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshPrintTunnelFlags{}
@@ -364,7 +360,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "change-remote",
-		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip",
+		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshChangeRemoteFlags{}
@@ -378,7 +374,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "close-tunnel",
-		ShortDescription: "Closes a tunnel for the provided vpn ip",
+		ShortDescription: "Closes a tunnel for the provided vpn addr",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
 			s := sshCloseTunnelFlags{}
@@ -392,7 +388,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "create-tunnel",
-		ShortDescription: "Creates a tunnel for the provided vpn ip and address",
+		ShortDescription: "Creates a tunnel for the provided vpn address",
 		Help:             "The lighthouses will be queried for real addresses but you can provide one as well.",
 		Flags: func() (*flag.FlagSet, interface{}) {
 			fl := flag.NewFlagSet("", flag.ContinueOnError)
@@ -407,8 +403,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "query-lighthouse",
-		ShortDescription: "Query the lighthouses for the provided vpn ip",
-		Help:             "This command is asynchronous. Only currently known udp ips will be printed.",
+		ShortDescription: "Query the lighthouses for the provided vpn address",
+		Help:             "This command is asynchronous. Only currently known udp addresses will be printed.",
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
 			return sshQueryLighthouse(f, fs, a, w)
 		},
@@ -418,7 +414,6 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter
 func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
 	fs, ok := a.(*sshListHostMapFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
@@ -430,7 +425,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 	}
 
 	sort.Slice(hm, func(i, j int) bool {
-		return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
+		return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
 	})
 
 	if fs.Json || fs.Pretty {
@@ -441,13 +436,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 
 		err := js.Encode(hm)
 		if err != nil {
-			//TODO
 			return nil
 		}
 
 	} else {
 		for _, v := range hm {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
+			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
 			if err != nil {
 				return err
 			}
@@ -460,13 +454,12 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
 	fs, ok := a.(*sshListHostMapFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
 	type lighthouseInfo struct {
-		VpnIp string    `json:"vpnIp"`
-		Addrs *CacheMap `json:"addrs"`
+		VpnAddr string    `json:"vpnAddr"`
+		Addrs   *CacheMap `json:"addrs"`
 	}
 
 	lightHouse.RLock()
@@ -474,15 +467,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 	x := 0
 	for k, v := range lightHouse.addrMap {
 		addrMap[x] = lighthouseInfo{
-			VpnIp: k.String(),
-			Addrs: v.CopyCache(),
+			VpnAddr: k.String(),
+			Addrs:   v.CopyCache(),
 		}
 		x++
 	}
 	lightHouse.RUnlock()
 
 	sort.Slice(addrMap, func(i, j int) bool {
-		return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
+		return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0
 	})
 
 	if fs.Json || fs.Pretty {
@@ -493,7 +486,6 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 
 		err := js.Encode(addrMap)
 		if err != nil {
-			//TODO
 			return nil
 		}
 
@@ -503,7 +495,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 			if err != nil {
 				return err
 			}
-			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b)))
 			if err != nil {
 				return err
 			}
@@ -541,20 +533,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter
 
 func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
 	var cm *CacheMap
-	rl := ifce.lightHouse.Query(vpnIp)
+	rl := ifce.lightHouse.Query(vpnAddr)
 	if rl != nil {
 		cm = rl.CopyCache()
 	}
@@ -564,26 +556,25 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	flags, ok := fs.(*sshCloseTunnelFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo == nil {
-		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
 	}
 
 	if !flags.LocalOnly {
@@ -605,29 +596,28 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	flags, ok := fs.(*sshCreateTunnelFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 
-	hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
+	hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
@@ -640,7 +630,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		}
 	}
 
-	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
+	hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
 	if addr.IsValid() {
 		hostInfo.SetRemote(addr)
 	}
@@ -651,12 +641,11 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	flags, ok := fs.(*sshChangeRemoteFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
 	if flags.Address == "" {
@@ -668,18 +657,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("Address could not be parsed")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo == nil {
-		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0]))
 	}
 
 	hostInfo.SetRemote(addr)
@@ -781,24 +770,23 @@ func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWri
 func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	args, ok := fs.(*sshPrintCertFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
-	cert := ifce.pki.GetCertState().Certificate
+	cert := ifce.pki.getCertState().GetDefaultCertificate()
 	if len(a) > 0 {
-		vpnIp, err := netip.ParseAddr(a[0])
+		vpnAddr, err := netip.ParseAddr(a[0])
 		if err != nil {
-			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+			return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 		}
 
-		if !vpnIp.IsValid() {
-			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		if !vpnAddr.IsValid() {
+			return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 		}
 
-		hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+		hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 		if hostInfo == nil {
-			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
 		}
 
 		cert = hostInfo.GetCert().Certificate
@@ -807,7 +795,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 	if args.Json || args.Pretty {
 		b, err := cert.MarshalJSON()
 		if err != nil {
-			//TODO: handle it
 			return nil
 		}
 
@@ -816,7 +803,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 			err := json.Indent(buf, b, "", "    ")
 			b = buf.Bytes()
 			if err != nil {
-				//TODO: handle it
 				return nil
 			}
 		}
@@ -827,7 +813,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 	if args.Raw {
 		b, err := cert.MarshalPEM()
 		if err != nil {
-			//TODO: handle it
 			return nil
 		}
 
@@ -840,7 +825,6 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	args, ok := fs.(*sshPrintTunnelFlags)
 	if !ok {
-		//TODO: error
 		w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
 		return nil
 	}
@@ -856,15 +840,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		Error          error
 		Type           string
 		State          string
-		PeerIp         netip.Addr
+		PeerAddr       netip.Addr
 		LocalIndex     uint32
 		RemoteIndex    uint32
 		RelayedThrough []netip.Addr
 	}
 
 	type RelayOutput struct {
-		NebulaIp    netip.Addr
-		RelayForIps []RelayFor
+		NebulaAddr    netip.Addr
+		RelayForAddrs []RelayFor
 	}
 
 	type CmdOutput struct {
@@ -880,16 +864,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	}
 
 	for k, v := range relays {
-		ro := RelayOutput{NebulaIp: v.vpnIp}
+		ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]}
 		co.Relays = append(co.Relays, &ro)
-		relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
+		relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
 		if relayHI == nil {
-			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
+			ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")})
 			continue
 		}
-		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
+		for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() {
 			rf := RelayFor{Error: nil}
-			r, ok := relayHI.relayState.GetRelayForByIp(vpnIp)
+			r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr)
 			if ok {
 				t := ""
 				switch r.Type {
@@ -913,19 +897,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 
 				rf.LocalIndex = r.LocalIndex
 				rf.RemoteIndex = r.RemoteIndex
-				rf.PeerIp = r.PeerIp
+				rf.PeerAddr = r.PeerAddr
 				rf.Type = t
 				rf.State = s
 				if rf.LocalIndex != k {
 					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
 				}
 			}
-			relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
+			relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr)
 			if relayedHI != nil {
 				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
 			}
 
-			ro.RelayForIps = append(ro.RelayForIps, rf)
+			ro.RelayForAddrs = append(ro.RelayForAddrs, rf)
 		}
 	}
 	err := enc.Encode(co)
@@ -938,26 +922,25 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
 	args, ok := fs.(*sshPrintTunnelFlags)
 	if !ok {
-		//TODO: error
 		return nil
 	}
 
 	if len(a) == 0 {
-		return w.WriteLine("No vpn ip was provided")
+		return w.WriteLine("No vpn address was provided")
 	}
 
-	vpnIp, err := netip.ParseAddr(a[0])
+	vpnAddr, err := netip.ParseAddr(a[0])
 	if err != nil {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+		return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 	}
 
-	if !vpnIp.IsValid() {
-		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
+	if !vpnAddr.IsValid() {
+		return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0]))
 	}
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr)
 	if hostInfo == nil {
-		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
+		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0]))
 	}
 
 	enc := json.NewEncoder(w.GetWriter())
@@ -971,13 +954,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
 
 	data := struct {
-		Name string `json:"name"`
-		Cidr string `json:"cidr"`
+		Name string         `json:"name"`
+		Cidr []netip.Prefix `json:"cidr"`
 	}{
 		Name: ifce.inside.Name(),
-		Cidr: ifce.inside.Cidr().String(),
+		Cidr: make([]netip.Prefix, len(ifce.inside.Networks())),
 	}
 
+	copy(data.Cidr, ifce.inside.Networks())
+
 	flags, ok := fs.(*sshDeviceInfoFlags)
 	if !ok {
 		return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)

+ 1 - 7
sshd/command.go

@@ -57,7 +57,6 @@ func execCommand(c *Command, args []string, w StringWriter) error {
 func dumpCommands(c *radix.Tree, w StringWriter) {
 	err := w.WriteLine("Available commands:")
 	if err != nil {
-		//TODO: log
 		return
 	}
 
@@ -67,10 +66,7 @@ func dumpCommands(c *radix.Tree, w StringWriter) {
 	}
 
 	sort.Strings(cmds)
-	err = w.Write(strings.Join(cmds, "\n") + "\n\n")
-	if err != nil {
-		//TODO: log
-	}
+	_ = w.Write(strings.Join(cmds, "\n") + "\n\n")
 }
 
 func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
@@ -119,8 +115,6 @@ func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error)
 	// We are printing a specific commands help text
 	cmd, err := lookupCommand(commands, a[0])
 	if err != nil {
-		//TODO: handle error
-		//TODO: message the user
 		return
 	}
 

+ 1 - 3
sshd/server.go

@@ -80,9 +80,7 @@ func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
 
 	s.config = &ssh.ServerConfig{
 		PublicKeyCallback: cc.Authenticate,
-		//TODO: AuthLogCallback: s.authAttempt,
-		//TODO: version string
-		ServerVersion: fmt.Sprintf("SSH-2.0-Nebula???"),
+		ServerVersion:     fmt.Sprintf("SSH-2.0-Nebula???"),
 	}
 
 	s.RegisterCommand(&Command{

+ 1 - 12
sshd/session.go

@@ -62,7 +62,6 @@ func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
 func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
 	for req := range in {
 		var err error
-		//TODO: maybe support window sizing?
 		switch req.Type {
 		case "shell":
 			if s.term == nil {
@@ -89,9 +88,7 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
 			req.Reply(true, nil)
 			s.dispatchCommand(payload.Value, &stringWriter{channel})
 
-			//TODO: Fix error handling and report the proper status back
 			status := struct{ Status uint32 }{uint32(0)}
-			//TODO: I think this is how we shut down a shell as well?
 			channel.SendRequest("exit-status", false, ssh.Marshal(status))
 			channel.Close()
 			return
@@ -110,7 +107,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
 }
 
 func (s *session) createTerm(channel ssh.Channel) *terminal.Terminal {
-	//TODO: PS1 with nebula cert name
 	term := terminal.NewTerminal(channel, s.c.User()+"@nebula > ")
 	term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
 		// key 9 is tab
@@ -137,7 +133,6 @@ func (s *session) handleInput(channel ssh.Channel) {
 	for {
 		line, err := s.term.ReadLine()
 		if err != nil {
-			//TODO: log
 			break
 		}
 
@@ -148,7 +143,6 @@ func (s *session) handleInput(channel ssh.Channel) {
 func (s *session) dispatchCommand(line string, w StringWriter) {
 	args, err := shlex.Split(line, true)
 	if err != nil {
-		//todo: LOG IT
 		return
 	}
 
@@ -159,13 +153,11 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
 
 	c, err := lookupCommand(s.commands, args[0])
 	if err != nil {
-		//TODO: handle the error
 		return
 	}
 
 	if c == nil {
 		err := w.WriteLine(fmt.Sprintf("did not understand: %s", line))
-		//TODO: log error
 		_ = err
 
 		dumpCommands(s.commands, w)
@@ -177,10 +169,7 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
 		return
 	}
 
-	err = execCommand(c, args[1:], w)
-	if err != nil {
-		//TODO: log the error
-	}
+	_ = execCommand(c, args[1:], w)
 	return
 }
 

+ 2 - 2
test/tun.go

@@ -16,8 +16,8 @@ func (NoopTun) Activate() error {
 	return nil
 }
 
-func (NoopTun) Cidr() netip.Prefix {
-	return netip.Prefix{}
+func (NoopTun) Networks() []netip.Prefix {
+	return []netip.Prefix{}
 }
 
 func (NoopTun) Name() string {

+ 4 - 4
timeout_test.go

@@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
 	assert.Equal(t, 0, tw.current)
 
 	fps := []firewall.Packet{
-		{LocalIP: netip.MustParseAddr("0.0.0.1")},
-		{LocalIP: netip.MustParseAddr("0.0.0.2")},
-		{LocalIP: netip.MustParseAddr("0.0.0.3")},
-		{LocalIP: netip.MustParseAddr("0.0.0.4")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.1")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.2")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.3")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.4")},
 	}
 
 	tw.Add(fps[0], time.Second*1)

+ 3 - 12
udp/conn.go

@@ -4,28 +4,19 @@ import (
 	"net/netip"
 
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/header"
 )
 
 const MTU = 9001
 
 type EncReader func(
 	addr netip.AddrPort,
-	out []byte,
-	packet []byte,
-	header *header.H,
-	fwPacket *firewall.Packet,
-	lhh LightHouseHandlerFunc,
-	nb []byte,
-	q int,
-	localCache firewall.ConntrackCache,
+	payload []byte,
 )
 
 type Conn interface {
 	Rebind() error
 	LocalAddr() (netip.AddrPort, error)
-	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
+	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
 	Close() error
@@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
 func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 	return netip.AddrPort{}, nil
 }
-func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
+func (NoopConn) ListenOut(_ EncReader) {
 	return
 }
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {

+ 0 - 10
udp/temp.go

@@ -1,10 +0,0 @@
-package udp
-
-import (
-	"net/netip"
-)
-
-//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
-
-// TODO: IPV6-WORK this can likely be removed now
-type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

+ 3 - 19
udp/udp_generic.go

@@ -15,8 +15,6 @@ import (
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/header"
 )
 
 type GenericConn struct {
@@ -60,7 +58,7 @@ func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 }
 
 func (u *GenericConn) ReloadConfig(c *config.C) {
-	// TODO
+
 }
 
 func NewUDPStatsEmitter(udpConns []Conn) func() {
@@ -72,12 +70,8 @@ type rawMessage struct {
 	Len uint32
 }
 
-func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
-	plaintext := make([]byte, MTU)
+func (u *GenericConn) ListenOut(r EncReader) {
 	buffer := make([]byte, MTU)
-	h := &header.H{}
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
 
 	for {
 		// Just read one packet at a time
@@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
 			return
 		}
 
-		r(
-			netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
-			plaintext[:0],
-			buffer[:n],
-			h,
-			fwPacket,
-			lhf,
-			nb,
-			q,
-			cache.Get(u.l),
-		)
+		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
 	}
 }

Some files were not shown because too many files changed in this diff