Browse Source

Support for ipv6 in the overlay with v2 certificates

---------

Co-authored-by: Jack Doan <[email protected]>
Nate Brown 9 months ago
parent
commit
f2c32421c4
86 changed files with 5418 additions and 3239 deletions
  1. 1 1
      Makefile
  2. 34 25
      calculated_remote.go
  3. 61 5
      calculated_remote_test.go
  4. 15 4
      cert/README.md
  5. 52 0
      cert/asn1.go
  6. 10 10
      cert/ca_pool_test.go
  7. 49 14
      cert/cert.go
  8. 31 31
      cert/cert_test.go
  9. 161 235
      cert/cert_v1.go
  10. 37 0
      cert/cert_v2.asn1
  11. 621 0
      cert/cert_v2.go
  12. 3 0
      cert/errors.go
  13. 12 7
      cert/pem.go
  14. 97 12
      cert/sign.go
  15. 46 22
      cmd/nebula-cert/ca.go
  16. 20 14
      cmd/nebula-cert/ca_test.go
  17. 11 6
      cmd/nebula-cert/print.go
  18. 61 2
      cmd/nebula-cert/print_test.go
  19. 168 66
      cmd/nebula-cert/sign.go
  20. 47 34
      cmd/nebula-cert/sign_test.go
  21. 53 29
      connection_manager.go
  22. 26 29
      connection_manager_test.go
  23. 26 21
      connection_state.go
  24. 25 19
      control.go
  25. 16 16
      control_test.go
  26. 42 21
      control_tester.go
  27. 74 36
      dns_server.go
  28. 20 5
      dns_server_test.go
  29. 208 171
      e2e/handshakes_test.go
  30. 2 2
      e2e/helpers.go
  31. 73 21
      e2e/helpers_test.go
  32. 4 3
      e2e/router/hostmap.go
  33. 44 22
      e2e/router/router.go
  34. 13 4
      examples/config.yml
  35. 67 55
      firewall.go
  36. 11 10
      firewall/packet.go
  37. 40 36
      firewall_test.go
  38. 0 1
      go.mod
  39. 0 2
      go.sum
  40. 162 80
      handshake_ix.go
  41. 135 87
      handshake_manager.go
  42. 19 11
      handshake_manager_test.go
  43. 95 86
      hostmap.go
  44. 23 33
      hostmap_test.go
  45. 2 2
      hostmap_tester.go
  46. 32 26
      inside.go
  47. 71 60
      interface.go
  48. 490 298
      lighthouse.go
  49. 93 66
      lighthouse_test.go
  50. 5 19
      main.go
  51. 467 171
      nebula.pb.go
  52. 23 9
      nebula.proto
  53. 135 82
      outside.go
  54. 71 10
      outside_test.go
  55. 1 1
      overlay/device.go
  56. 22 12
      overlay/route.go
  57. 39 33
      overlay/route_test.go
  58. 9 9
      overlay/tun.go
  59. 11 11
      overlay/tun_android.go
  60. 207 214
      overlay/tun_darwin.go
  61. 8 8
      overlay/tun_disabled.go
  62. 25 15
      overlay/tun_freebsd.go
  63. 10 10
      overlay/tun_ios.go
  64. 108 69
      overlay/tun_linux.go
  65. 28 17
      overlay/tun_netbsd.go
  66. 29 19
      overlay/tun_openbsd.go
  67. 17 17
      overlay/tun_tester.go
  68. 0 208
      overlay/tun_water_windows.go
  69. 239 13
      overlay/tun_windows.go
  70. 0 252
      overlay/tun_wintun_windows.go
  71. 6 6
      overlay/user.go
  72. 338 67
      pki.go
  73. 122 63
      relay_manager.go
  74. 34 28
      remote_list.go
  75. 20 20
      remote_list_test.go
  76. 2 2
      service/service.go
  77. 1 1
      service/service_test.go
  78. 20 17
      ssh.go
  79. 2 2
      test/tun.go
  80. 4 4
      timeout_test.go
  81. 3 12
      udp/conn.go
  82. 0 10
      udp/temp.go
  83. 2 18
      udp/udp_generic.go
  84. 3 23
      udp/udp_linux.go
  85. 2 19
      udp/udp_rio_windows.go
  86. 2 8
      udp/udp_tester.go

+ 1 - 1
Makefile

@@ -196,7 +196,7 @@ bench-cpu-long:
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 	go tool pprof go-audit.test cpu.pprof
 
 
-proto: nebula.pb.go cert/cert.pb.go
+proto: nebula.pb.go cert/cert_v1.pb.go
 
 
 nebula.pb.go: nebula.proto .FORCE
 nebula.pb.go: nebula.proto .FORCE
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster
 	go build github.com/gogo/protobuf/protoc-gen-gogofaster

+ 34 - 25
calculated_remote.go

@@ -21,7 +21,11 @@ type calculatedRemote struct {
 	port  uint32
 	port  uint32
 }
 }
 
 
-func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() {
+		return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr)
+	}
+
 	masked := maskCidr.Masked()
 	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
 		return nil, fmt.Errorf("invalid port: %d", port)
@@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 }
 
 
-func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
-	// Combine the masked bytes of the "mask" IP with the unmasked bytes
-	// of the overlay IP
-	if c.ipNet.Addr().Is4() {
-		return c.apply4(ip)
-	}
-	return c.apply6(ip)
-}
-
-func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK this can be less crappy
+func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort {
+	// Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP
 	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
 	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
 	mask := binary.BigEndian.Uint32(maskb[:])
 	mask := binary.BigEndian.Uint32(maskb[:])
 
 
 	b := c.mask.Addr().As4()
 	b := c.mask.Addr().As4()
-	maskIp := binary.BigEndian.Uint32(b[:])
+	maskAddr := binary.BigEndian.Uint32(b[:])
 
 
-	b = ip.As4()
-	intIp := binary.BigEndian.Uint32(b[:])
+	b = addr.As4()
+	intAddr := binary.BigEndian.Uint32(b[:])
 
 
-	return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+	return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port}
 }
 }
 
 
-func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
-	//TODO: IPV6-WORK
-	panic("Can not calculate ipv6 remote addresses")
+func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort {
+	mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	maskAddr := c.mask.Addr().As16()
+	calcAddr := addr.As16()
+
+	ap := V6AddrPort{Port: c.port}
+
+	maskb := binary.BigEndian.Uint64(mask[:8])
+	maskAddrb := binary.BigEndian.Uint64(maskAddr[:8])
+	calcAddrb := binary.BigEndian.Uint64(calcAddr[:8])
+	ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	maskb = binary.BigEndian.Uint64(mask[8:])
+	maskAddrb = binary.BigEndian.Uint64(maskAddr[8:])
+	calcAddrb = binary.BigEndian.Uint64(calcAddr[8:])
+	ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb)
+
+	return &ap
 }
 }
 
 
 func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
 func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
@@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 		}
 
 
-		//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
-		entry, err := newCalculatedRemotesListFromConfig(rawValue)
+		entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
 		}
@@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu
 	return calculatedRemotes, nil
 	return calculatedRemotes, nil
 }
 }
 
 
-func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
+func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) {
 	rawList, ok := raw.([]any)
 	rawList, ok := raw.([]any)
 	if !ok {
 	if !ok {
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
 		return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw)
@@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 
 
 	var l []*calculatedRemote
 	var l []*calculatedRemote
 	for _, e := range rawList {
 	for _, e := range rawList {
-		c, err := newCalculatedRemotesEntryFromConfig(e)
+		c, err := newCalculatedRemotesEntryFromConfig(cidr, e)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 			return nil, fmt.Errorf("calculated_remotes entry: %w", err)
 		}
 		}
@@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) {
 	return l, nil
 	return l, nil
 }
 }
 
 
-func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
+func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) {
 	rawMap, ok := raw.(map[any]any)
 	rawMap, ok := raw.(map[any]any)
 	if !ok {
 	if !ok {
 		return nil, fmt.Errorf("invalid type: %T", raw)
 		return nil, fmt.Errorf("invalid type: %T", raw)
@@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 	}
 
 
-	return newCalculatedRemote(maskCidr, port)
+	return newCalculatedRemote(cidr, maskCidr, port)
 }
 }

+ 61 - 5
calculated_remote_test.go

@@ -9,10 +9,9 @@ import (
 )
 )
 
 
 func TestCalculatedRemoteApply(t *testing.T) {
 func TestCalculatedRemoteApply(t *testing.T) {
-	ipNet, err := netip.ParsePrefix("192.168.1.0/24")
-	require.NoError(t, err)
-
-	c, err := newCalculatedRemote(ipNet, 4242)
+	// Test v4 addresses
+	ipNet := netip.MustParsePrefix("192.168.1.0/24")
+	c, err := newCalculatedRemote(ipNet, ipNet, 4242)
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
 	input, err := netip.ParseAddr("10.0.10.182")
 	input, err := netip.ParseAddr("10.0.10.182")
@@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) {
 	expected, err := netip.ParseAddr("192.168.1.182")
 	expected, err := netip.ParseAddr("192.168.1.182")
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
+	assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
+
+	// Test v6 addresses
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+
+	// Test v6 addresses part 2
+	ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48")
+	c, err = newCalculatedRemote(ipNet, ipNet, 4242)
+	require.NoError(t, err)
+
+	input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
+	assert.NoError(t, err)
+
+	assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
+}
+
+func Test_newCalculatedRemote(t *testing.T) {
+	c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32")
+	require.Nil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
+
+	c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242)
+	require.NoError(t, err)
+	require.NotNil(t, c)
 }
 }

+ 15 - 4
cert/README.md

@@ -2,14 +2,25 @@
 
 
 This is a library for interacting with `nebula` style certificates and authorities.
 This is a library for interacting with `nebula` style certificates and authorities.
 
 
-A `protobuf` definition of the certificate format is also included
+There are now 2 versions of `nebula` certificates:
 
 
-### Compiling the protobuf definition
+## v1
 
 
-Make sure you have `protoc` installed.
+This version is deprecated.
+
+A `protobuf` definition of the certificate format is included at `cert_v1.proto`
+
+To compile the definition you will need `protoc` installed.
 
 
 To compile for `go` with the same version of protobuf specified in go.mod:
 To compile for `go` with the same version of protobuf specified in go.mod:
 
 
 ```bash
 ```bash
-make
+make proto
 ```
 ```
+
+## v2
+
+This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate
+future certificate changes better than v1.
+
+`cert_v2.asn1` defines the wire format and can be used to compile marshalers.

+ 52 - 0
cert/asn1.go

@@ -0,0 +1,52 @@
+package cert
+
+import (
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+)
+
+// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value
+// https://github.com/golang/go/issues/64811#issuecomment-1944446920
+func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0] > 0
+		return true
+	}
+
+	return false
+}
+
+// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value
+// Similar issue as with readOptionalASN1Boolean
+func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool {
+	var present bool
+	var child cryptobyte.String
+	if !b.ReadOptionalASN1(&child, &present, tag) {
+		return false
+	}
+
+	if !present {
+		*out = defaultValue
+		return true
+	}
+
+	// Ensure we have 1 byte
+	if len(child) == 1 {
+		*out = child[0]
+		return true
+	}
+
+	return false
+}

+ 10 - 10
cert/ca_pool_test.go

@@ -63,31 +63,31 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
 
 
 	rootCA := certificateV1{
 	rootCA := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name: "nebula root ca",
+			name: "nebula root ca",
 		},
 		},
 	}
 	}
 
 
 	rootCA01 := certificateV1{
 	rootCA01 := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name: "nebula root ca 01",
+			name: "nebula root ca 01",
 		},
 		},
 	}
 	}
 
 
 	rootCAP256 := certificateV1{
 	rootCAP256 := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name: "nebula P256 test",
+			name: "nebula P256 test",
 		},
 		},
 	}
 	}
 
 
 	p, err := NewCAPoolFromPEM([]byte(noNewLines))
 	p, err := NewCAPoolFromPEM([]byte(noNewLines))
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
-	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
+	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
 
 
 	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
 	pp, err := NewCAPoolFromPEM([]byte(withNewLines))
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
-	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
+	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
 
 
 	// expired cert, no valid certs
 	// expired cert, no valid certs
 	ppp, err := NewCAPoolFromPEM([]byte(expired))
 	ppp, err := NewCAPoolFromPEM([]byte(expired))
@@ -97,13 +97,13 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
 	// expired cert, with valid certs
 	// expired cert, with valid certs
 	pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
 	pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...))
 	assert.Equal(t, ErrExpired, err)
 	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name)
-	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name)
+	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name)
+	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name)
 	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
 	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired")
 	assert.Equal(t, len(pppp.CAs), 3)
 	assert.Equal(t, len(pppp.CAs), 3)
 
 
 	ppppp, err := NewCAPoolFromPEM([]byte(p256))
 	ppppp, err := NewCAPoolFromPEM([]byte(p256))
 	assert.Nil(t, err)
 	assert.Nil(t, err)
-	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name)
+	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.name)
 	assert.Equal(t, len(ppppp.CAs), 1)
 	assert.Equal(t, len(ppppp.CAs), 1)
 }
 }

+ 49 - 14
cert/cert.go

@@ -1,15 +1,17 @@
 package cert
 package cert
 
 
 import (
 import (
+	"fmt"
 	"net/netip"
 	"net/netip"
 	"time"
 	"time"
 )
 )
 
 
-type Version int
+type Version uint8
 
 
 const (
 const (
-	Version1 Version = 1
-	Version2 Version = 2
+	VersionPre1 Version = 0
+	Version1    Version = 1
+	Version2    Version = 2
 )
 )
 
 
 type Certificate interface {
 type Certificate interface {
@@ -107,23 +109,56 @@ type CachedCertificate struct {
 	signerFingerprint string
 	signerFingerprint string
 }
 }
 
 
-// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate.
-func UnmarshalCertificate(b []byte) (Certificate, error) {
-	c, err := unmarshalCertificateV1(b, true)
-	if err != nil {
-		return nil, err
-	}
-	return c, nil
+func (cc *CachedCertificate) String() string {
+	return cc.Certificate.String()
 }
 }
 
 
-// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
+// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake.
 // Handshakes save space by placing the peers public key in a different part of the packet, we have to
 // Handshakes save space by placing the peers public key in a different part of the packet, we have to
 // reassemble the actual certificate structure with that in mind.
 // reassemble the actual certificate structure with that in mind.
-func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) {
-	c, err := unmarshalCertificateV1(b, false)
+func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
+	if publicKey == nil {
+		return nil, ErrNoPeerStaticKey
+	}
+
+	if rawCertBytes == nil {
+		return nil, ErrNoPayload
+	}
+
+	c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
+	if err != nil {
+		return nil, fmt.Errorf("error unmarshaling cert: %w", err)
+	}
+
+	cc, err := caPool.VerifyCertificate(time.Now(), c)
+	if err != nil {
+		return nil, fmt.Errorf("certificate validation failed: %w", err)
+	}
+
+	return cc, nil
+}
+
+func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
+	var c Certificate
+	var err error
+
+	switch v {
+	case VersionPre1, Version1:
+		c, err = unmarshalCertificateV1(b, publicKey)
+	case Version2:
+		c, err = unmarshalCertificateV2(b, publicKey, curve)
+	default:
+		//TODO: make a static var
+		return nil, fmt.Errorf("unknown certificate version %d", v)
+	}
+
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	c.details.PublicKey = publicKey
+
+	if c.Curve() != curve {
+		return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
+	}
+
 	return c, nil
 	return c, nil
 }
 }

+ 31 - 31
cert/cert_test.go

@@ -24,21 +24,21 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
 
 
 	nc := certificateV1{
 	nc := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name: "testing",
-			Ips: []netip.Prefix{
+			name: "testing",
+			networks: []netip.Prefix{
 				mustParsePrefixUnmapped("10.1.1.1/24"),
 				mustParsePrefixUnmapped("10.1.1.1/24"),
 				mustParsePrefixUnmapped("10.1.1.2/16"),
 				mustParsePrefixUnmapped("10.1.1.2/16"),
 			},
 			},
-			Subnets: []netip.Prefix{
+			unsafeNetworks: []netip.Prefix{
 				mustParsePrefixUnmapped("9.1.1.2/24"),
 				mustParsePrefixUnmapped("9.1.1.2/24"),
 				mustParsePrefixUnmapped("9.1.1.3/16"),
 				mustParsePrefixUnmapped("9.1.1.3/16"),
 			},
 			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: before,
-			NotAfter:  after,
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: before,
+			notAfter:  after,
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
 		},
 		},
 		signature: []byte("1234567890abcedfghij1234567890ab"),
 		signature: []byte("1234567890abcedfghij1234567890ab"),
 	}
 	}
@@ -47,20 +47,20 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	//t.Log("Cert size:", len(b))
 	//t.Log("Cert size:", len(b))
 
 
-	nc2, err := unmarshalCertificateV1(b, true)
+	nc2, err := unmarshalCertificateV1(b, nil)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	assert.Equal(t, nc.signature, nc2.Signature())
 	assert.Equal(t, nc.signature, nc2.Signature())
-	assert.Equal(t, nc.details.Name, nc2.Name())
-	assert.Equal(t, nc.details.NotBefore, nc2.NotBefore())
-	assert.Equal(t, nc.details.NotAfter, nc2.NotAfter())
-	assert.Equal(t, nc.details.PublicKey, nc2.PublicKey())
-	assert.Equal(t, nc.details.IsCA, nc2.IsCA())
+	assert.Equal(t, nc.details.name, nc2.Name())
+	assert.Equal(t, nc.details.notBefore, nc2.NotBefore())
+	assert.Equal(t, nc.details.notAfter, nc2.NotAfter())
+	assert.Equal(t, nc.details.publicKey, nc2.PublicKey())
+	assert.Equal(t, nc.details.isCA, nc2.IsCA())
 
 
-	assert.Equal(t, nc.details.Ips, nc2.Networks())
-	assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks())
+	assert.Equal(t, nc.details.networks, nc2.Networks())
+	assert.Equal(t, nc.details.unsafeNetworks, nc2.UnsafeNetworks())
 
 
-	assert.Equal(t, nc.details.Groups, nc2.Groups())
+	assert.Equal(t, nc.details.groups, nc2.Groups())
 }
 }
 
 
 //func TestNebulaCertificate_Sign(t *testing.T) {
 //func TestNebulaCertificate_Sign(t *testing.T) {
@@ -150,8 +150,8 @@ func TestMarshalingNebulaCertificate(t *testing.T) {
 func TestNebulaCertificate_Expired(t *testing.T) {
 func TestNebulaCertificate_Expired(t *testing.T) {
 	nc := certificateV1{
 	nc := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			NotBefore: time.Now().Add(time.Second * -60).Round(time.Second),
-			NotAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
+			notBefore: time.Now().Add(time.Second * -60).Round(time.Second),
+			notAfter:  time.Now().Add(time.Second * 60).Round(time.Second),
 		},
 		},
 	}
 	}
 
 
@@ -166,21 +166,21 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
 
 
 	nc := certificateV1{
 	nc := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name: "testing",
-			Ips: []netip.Prefix{
+			name: "testing",
+			networks: []netip.Prefix{
 				mustParsePrefixUnmapped("10.1.1.1/24"),
 				mustParsePrefixUnmapped("10.1.1.1/24"),
 				mustParsePrefixUnmapped("10.1.1.2/16"),
 				mustParsePrefixUnmapped("10.1.1.2/16"),
 			},
 			},
-			Subnets: []netip.Prefix{
+			unsafeNetworks: []netip.Prefix{
 				mustParsePrefixUnmapped("9.1.1.2/24"),
 				mustParsePrefixUnmapped("9.1.1.2/24"),
 				mustParsePrefixUnmapped("9.1.1.3/16"),
 				mustParsePrefixUnmapped("9.1.1.3/16"),
 			},
 			},
-			Groups:    []string{"test-group1", "test-group2", "test-group3"},
-			NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
-			NotAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
-			PublicKey: pubKey,
-			IsCA:      false,
-			Issuer:    "1234567890abcedfghij1234567890ab",
+			groups:    []string{"test-group1", "test-group2", "test-group3"},
+			notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC),
+			notAfter:  time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC),
+			publicKey: pubKey,
+			isCA:      false,
+			issuer:    "1234567890abcedfghij1234567890ab",
 		},
 		},
 		signature: []byte("1234567890abcedfghij1234567890ab"),
 		signature: []byte("1234567890abcedfghij1234567890ab"),
 	}
 	}
@@ -189,7 +189,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(
 	assert.Equal(
 		t,
 		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}",
+		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
 		string(b),
 		string(b),
 	)
 	)
 }
 }
@@ -526,7 +526,7 @@ func TestNebulaCertificate_Copy(t *testing.T) {
 func TestUnmarshalNebulaCertificate(t *testing.T) {
 func TestUnmarshalNebulaCertificate(t *testing.T) {
 	// Test that we don't panic with an invalid certificate (#332)
 	// Test that we don't panic with an invalid certificate (#332)
 	data := []byte("\x98\x00\x00")
 	data := []byte("\x98\x00\x00")
-	_, err := unmarshalCertificateV1(data, true)
+	_, err := unmarshalCertificateV1(data, nil)
 	assert.EqualError(t, err, "encoded Details was nil")
 	assert.EqualError(t, err, "encoded Details was nil")
 }
 }
 
 

+ 161 - 235
cert/cert_v1.go

@@ -6,19 +6,16 @@ import (
 	"crypto/ecdsa"
 	"crypto/ecdsa"
 	"crypto/ed25519"
 	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/elliptic"
-	"crypto/rand"
 	"crypto/sha256"
 	"crypto/sha256"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/hex"
 	"encoding/hex"
 	"encoding/json"
 	"encoding/json"
 	"encoding/pem"
 	"encoding/pem"
 	"fmt"
 	"fmt"
-	"math/big"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
 	"time"
 	"time"
 
 
-	"github.com/slackhq/nebula/pkclient"
 	"golang.org/x/crypto/curve25519"
 	"golang.org/x/crypto/curve25519"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/proto"
 )
 )
@@ -31,71 +28,71 @@ type certificateV1 struct {
 }
 }
 
 
 type detailsV1 struct {
 type detailsV1 struct {
-	Name      string
-	Ips       []netip.Prefix
-	Subnets   []netip.Prefix
-	Groups    []string
-	NotBefore time.Time
-	NotAfter  time.Time
-	PublicKey []byte
-	IsCA      bool
-	Issuer    string
-
-	Curve Curve
+	name           string
+	networks       []netip.Prefix
+	unsafeNetworks []netip.Prefix
+	groups         []string
+	notBefore      time.Time
+	notAfter       time.Time
+	publicKey      []byte
+	isCA           bool
+	issuer         string
+
+	curve Curve
 }
 }
 
 
 type m map[string]interface{}
 type m map[string]interface{}
 
 
-func (nc *certificateV1) Version() Version {
+func (c *certificateV1) Version() Version {
 	return Version1
 	return Version1
 }
 }
 
 
-func (nc *certificateV1) Curve() Curve {
-	return nc.details.Curve
+func (c *certificateV1) Curve() Curve {
+	return c.details.curve
 }
 }
 
 
-func (nc *certificateV1) Groups() []string {
-	return nc.details.Groups
+func (c *certificateV1) Groups() []string {
+	return c.details.groups
 }
 }
 
 
-func (nc *certificateV1) IsCA() bool {
-	return nc.details.IsCA
+func (c *certificateV1) IsCA() bool {
+	return c.details.isCA
 }
 }
 
 
-func (nc *certificateV1) Issuer() string {
-	return nc.details.Issuer
+func (c *certificateV1) Issuer() string {
+	return c.details.issuer
 }
 }
 
 
-func (nc *certificateV1) Name() string {
-	return nc.details.Name
+func (c *certificateV1) Name() string {
+	return c.details.name
 }
 }
 
 
-func (nc *certificateV1) Networks() []netip.Prefix {
-	return nc.details.Ips
+func (c *certificateV1) Networks() []netip.Prefix {
+	return c.details.networks
 }
 }
 
 
-func (nc *certificateV1) NotAfter() time.Time {
-	return nc.details.NotAfter
+func (c *certificateV1) NotAfter() time.Time {
+	return c.details.notAfter
 }
 }
 
 
-func (nc *certificateV1) NotBefore() time.Time {
-	return nc.details.NotBefore
+func (c *certificateV1) NotBefore() time.Time {
+	return c.details.notBefore
 }
 }
 
 
-func (nc *certificateV1) PublicKey() []byte {
-	return nc.details.PublicKey
+func (c *certificateV1) PublicKey() []byte {
+	return c.details.publicKey
 }
 }
 
 
-func (nc *certificateV1) Signature() []byte {
-	return nc.signature
+func (c *certificateV1) Signature() []byte {
+	return c.signature
 }
 }
 
 
-func (nc *certificateV1) UnsafeNetworks() []netip.Prefix {
-	return nc.details.Subnets
+func (c *certificateV1) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
 }
 }
 
 
-func (nc *certificateV1) Fingerprint() (string, error) {
-	b, err := nc.Marshal()
+func (c *certificateV1) Fingerprint() (string, error) {
+	b, err := c.Marshal()
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
@@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) {
 	return hex.EncodeToString(sum[:]), nil
 	return hex.EncodeToString(sum[:]), nil
 }
 }
 
 
-func (nc *certificateV1) CheckSignature(key []byte) bool {
-	b, err := proto.Marshal(nc.getRawDetails())
+func (c *certificateV1) CheckSignature(key []byte) bool {
+	b, err := proto.Marshal(c.getRawDetails())
 	if err != nil {
 	if err != nil {
 		return false
 		return false
 	}
 	}
-	switch nc.details.Curve {
+	switch c.details.curve {
 	case Curve_CURVE25519:
 	case Curve_CURVE25519:
-		return ed25519.Verify(key, b, nc.signature)
+		return ed25519.Verify(key, b, c.signature)
 	case Curve_P256:
 	case Curve_P256:
 		x, y := elliptic.Unmarshal(elliptic.P256(), key)
 		x, y := elliptic.Unmarshal(elliptic.P256(), key)
 		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
 		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
 		hashed := sha256.Sum256(b)
 		hashed := sha256.Sum256(b)
-		return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
 	default:
 	default:
 		return false
 		return false
 	}
 	}
 }
 }
 
 
-func (nc *certificateV1) Expired(t time.Time) bool {
-	return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t)
+func (c *certificateV1) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
 }
 }
 
 
-func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
-	if curve != nc.details.Curve {
+func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.details.curve {
 		return fmt.Errorf("curve in cert and private key supplied don't match")
 		return fmt.Errorf("curve in cert and private key supplied don't match")
 	}
 	}
-	if nc.details.IsCA {
+	if c.details.isCA {
 		switch curve {
 		switch curve {
 		case Curve_CURVE25519:
 		case Curve_CURVE25519:
 			// the call to PublicKey below will panic slice bounds out of range otherwise
 			// the call to PublicKey below will panic slice bounds out of range otherwise
@@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
 				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
 			}
 			}
 
 
-			if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) {
+			if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
 				return fmt.Errorf("public key in cert and private key supplied don't match")
 				return fmt.Errorf("public key in cert and private key supplied don't match")
 			}
 			}
 		case Curve_P256:
 		case Curve_P256:
@@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 				return fmt.Errorf("cannot parse private key as P256: %w", err)
 				return fmt.Errorf("cannot parse private key as P256: %w", err)
 			}
 			}
 			pub := privkey.PublicKey().Bytes()
 			pub := privkey.PublicKey().Bytes()
-			if !bytes.Equal(pub, nc.details.PublicKey) {
+			if !bytes.Equal(pub, c.details.publicKey) {
 				return fmt.Errorf("public key in cert and private key supplied don't match")
 				return fmt.Errorf("public key in cert and private key supplied don't match")
 			}
 			}
 		default:
 		default:
@@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 	default:
 	default:
 		return fmt.Errorf("invalid curve: %s", curve)
 		return fmt.Errorf("invalid curve: %s", curve)
 	}
 	}
-	if !bytes.Equal(pub, nc.details.PublicKey) {
+	if !bytes.Equal(pub, c.details.publicKey) {
 		return fmt.Errorf("public key in cert and private key supplied don't match")
 		return fmt.Errorf("public key in cert and private key supplied don't match")
 	}
 	}
 
 
@@ -181,173 +178,155 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error {
 }
 }
 
 
 // getRawDetails marshals the raw details into protobuf ready struct
 // getRawDetails marshals the raw details into protobuf ready struct
-func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
+func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails {
 	rd := &RawNebulaCertificateDetails{
 	rd := &RawNebulaCertificateDetails{
-		Name:      nc.details.Name,
-		Groups:    nc.details.Groups,
-		NotBefore: nc.details.NotBefore.Unix(),
-		NotAfter:  nc.details.NotAfter.Unix(),
-		PublicKey: make([]byte, len(nc.details.PublicKey)),
-		IsCA:      nc.details.IsCA,
-		Curve:     nc.details.Curve,
+		Name:      c.details.name,
+		Groups:    c.details.groups,
+		NotBefore: c.details.notBefore.Unix(),
+		NotAfter:  c.details.notAfter.Unix(),
+		PublicKey: make([]byte, len(c.details.publicKey)),
+		IsCA:      c.details.isCA,
+		Curve:     c.details.curve,
 	}
 	}
 
 
-	for _, ipNet := range nc.details.Ips {
+	for _, ipNet := range c.details.networks {
 		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
 		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
 		rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
 		rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask))
 	}
 	}
 
 
-	for _, ipNet := range nc.details.Subnets {
+	for _, ipNet := range c.details.unsafeNetworks {
 		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
 		mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen())
 		rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
 		rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask))
 	}
 	}
 
 
-	copy(rd.PublicKey, nc.details.PublicKey[:])
+	copy(rd.PublicKey, c.details.publicKey[:])
 
 
 	// I know, this is terrible
 	// I know, this is terrible
-	rd.Issuer, _ = hex.DecodeString(nc.details.Issuer)
+	rd.Issuer, _ = hex.DecodeString(c.details.issuer)
 
 
 	return rd
 	return rd
 }
 }
 
 
-func (nc *certificateV1) String() string {
-	if nc == nil {
-		return "Certificate {}\n"
-	}
-
-	s := "NebulaCertificate {\n"
-	s += "\tDetails {\n"
-	s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name)
-
-	if len(nc.details.Ips) > 0 {
-		s += "\t\tIps: [\n"
-		for _, ip := range nc.details.Ips {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tIps: []\n"
-	}
-
-	if len(nc.details.Subnets) > 0 {
-		s += "\t\tSubnets: [\n"
-		for _, ip := range nc.details.Subnets {
-			s += fmt.Sprintf("\t\t\t%v\n", ip.String())
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tSubnets: []\n"
-	}
-
-	if len(nc.details.Groups) > 0 {
-		s += "\t\tGroups: [\n"
-		for _, g := range nc.details.Groups {
-			s += fmt.Sprintf("\t\t\t\"%v\"\n", g)
-		}
-		s += "\t\t]\n"
-	} else {
-		s += "\t\tGroups: []\n"
-	}
-
-	s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore)
-	s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter)
-	s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA)
-	s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer)
-	s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey)
-	s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve)
-	s += "\t}\n"
-	fp, err := nc.Fingerprint()
-	if err == nil {
-		s += fmt.Sprintf("\tFingerprint: %s\n", fp)
+func (c *certificateV1) String() string {
+	b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
+	if err != nil {
+		return "<error marshalling certificate>"
 	}
 	}
-	s += fmt.Sprintf("\tSignature: %x\n", nc.Signature())
-	s += "}"
-
-	return s
+	return string(b)
 }
 }
 
 
-func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) {
-	pubKey := nc.details.PublicKey
-	nc.details.PublicKey = nil
-	rawCertNoKey, err := nc.Marshal()
+func (c *certificateV1) MarshalForHandshakes() ([]byte, error) {
+	pubKey := c.details.publicKey
+	c.details.publicKey = nil
+	rawCertNoKey, err := c.Marshal()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	nc.details.PublicKey = pubKey
+	c.details.publicKey = pubKey
 	return rawCertNoKey, nil
 	return rawCertNoKey, nil
 }
 }
 
 
-func (nc *certificateV1) Marshal() ([]byte, error) {
+func (c *certificateV1) Marshal() ([]byte, error) {
 	rc := RawNebulaCertificate{
 	rc := RawNebulaCertificate{
-		Details:   nc.getRawDetails(),
-		Signature: nc.signature,
+		Details:   c.getRawDetails(),
+		Signature: c.signature,
 	}
 	}
 
 
 	return proto.Marshal(&rc)
 	return proto.Marshal(&rc)
 }
 }
 
 
-func (nc *certificateV1) MarshalPEM() ([]byte, error) {
-	b, err := nc.Marshal()
+func (c *certificateV1) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
 	return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil
 }
 }
 
 
-func (nc *certificateV1) MarshalJSON() ([]byte, error) {
-	fp, _ := nc.Fingerprint()
-	jc := m{
+func (c *certificateV1) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.marshalJSON())
+}
+
+func (c *certificateV1) marshalJSON() m {
+	fp, _ := c.Fingerprint()
+	return m{
+		"version": Version1,
 		"details": m{
 		"details": m{
-			"name":      nc.details.Name,
-			"ips":       nc.details.Ips,
-			"subnets":   nc.details.Subnets,
-			"groups":    nc.details.Groups,
-			"notBefore": nc.details.NotBefore,
-			"notAfter":  nc.details.NotAfter,
-			"publicKey": fmt.Sprintf("%x", nc.details.PublicKey),
-			"isCa":      nc.details.IsCA,
-			"issuer":    nc.details.Issuer,
-			"curve":     nc.details.Curve.String(),
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"publicKey":      fmt.Sprintf("%x", c.details.publicKey),
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+			"curve":          c.details.curve.String(),
 		},
 		},
 		"fingerprint": fp,
 		"fingerprint": fp,
-		"signature":   fmt.Sprintf("%x", nc.Signature()),
+		"signature":   fmt.Sprintf("%x", c.Signature()),
 	}
 	}
-	return json.Marshal(jc)
 }
 }
 
 
-func (nc *certificateV1) Copy() Certificate {
-	c := &certificateV1{
+func (c *certificateV1) Copy() Certificate {
+	nc := &certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name:      nc.details.Name,
-			Groups:    make([]string, len(nc.details.Groups)),
-			Ips:       make([]netip.Prefix, len(nc.details.Ips)),
-			Subnets:   make([]netip.Prefix, len(nc.details.Subnets)),
-			NotBefore: nc.details.NotBefore,
-			NotAfter:  nc.details.NotAfter,
-			PublicKey: make([]byte, len(nc.details.PublicKey)),
-			IsCA:      nc.details.IsCA,
-			Issuer:    nc.details.Issuer,
+			name:           c.details.name,
+			groups:         make([]string, len(c.details.groups)),
+			networks:       make([]netip.Prefix, len(c.details.networks)),
+			unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
+			notBefore:      c.details.notBefore,
+			notAfter:       c.details.notAfter,
+			publicKey:      make([]byte, len(c.details.publicKey)),
+			isCA:           c.details.isCA,
+			issuer:         c.details.issuer,
+			curve:          c.details.curve,
 		},
 		},
-		signature: make([]byte, len(nc.signature)),
+		signature: make([]byte, len(c.signature)),
 	}
 	}
 
 
-	copy(c.signature, nc.signature)
-	copy(c.details.Groups, nc.details.Groups)
-	copy(c.details.PublicKey, nc.details.PublicKey)
+	copy(nc.signature, c.signature)
+	copy(nc.details.groups, c.details.groups)
+	copy(nc.details.publicKey, c.details.publicKey)
+	copy(nc.details.networks, c.details.networks)
+	copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
 
 
-	for i, p := range nc.details.Ips {
-		c.details.Ips[i] = p
+	return nc
+}
+
+func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV1{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		publicKey:      t.PublicKey,
+		isCA:           t.IsCA,
+		curve:          t.Curve,
+		issuer:         t.issuer,
 	}
 	}
 
 
-	for i, p := range nc.details.Subnets {
-		c.details.Subnets[i] = p
+	return nil
+}
+
+func (c *certificateV1) marshalForSigning() ([]byte, error) {
+	b, err := proto.Marshal(c.getRawDetails())
+	if err != nil {
+		return nil, err
 	}
 	}
+	return b, nil
+}
 
 
-	return c
+func (c *certificateV1) setSignature(b []byte) error {
+	c.signature = b
+	return nil
 }
 }
 
 
 // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
 // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert
-func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) {
+// if the publicKey is provided here then it is not required to be present in `b`
+func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) {
 	if len(b) == 0 {
 	if len(b) == 0 {
 		return nil, fmt.Errorf("nil byte array")
 		return nil, fmt.Errorf("nil byte array")
 	}
 	}
@@ -371,27 +350,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
 
 
 	nc := certificateV1{
 	nc := certificateV1{
 		details: detailsV1{
 		details: detailsV1{
-			Name:      rc.Details.Name,
-			Groups:    make([]string, len(rc.Details.Groups)),
-			Ips:       make([]netip.Prefix, len(rc.Details.Ips)/2),
-			Subnets:   make([]netip.Prefix, len(rc.Details.Subnets)/2),
-			NotBefore: time.Unix(rc.Details.NotBefore, 0),
-			NotAfter:  time.Unix(rc.Details.NotAfter, 0),
-			PublicKey: make([]byte, len(rc.Details.PublicKey)),
-			IsCA:      rc.Details.IsCA,
-			Curve:     rc.Details.Curve,
+			name:           rc.Details.Name,
+			groups:         make([]string, len(rc.Details.Groups)),
+			networks:       make([]netip.Prefix, len(rc.Details.Ips)/2),
+			unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
+			notBefore:      time.Unix(rc.Details.NotBefore, 0),
+			notAfter:       time.Unix(rc.Details.NotAfter, 0),
+			publicKey:      make([]byte, len(rc.Details.PublicKey)),
+			isCA:           rc.Details.IsCA,
+			curve:          rc.Details.Curve,
 		},
 		},
 		signature: make([]byte, len(rc.Signature)),
 		signature: make([]byte, len(rc.Signature)),
 	}
 	}
 
 
 	copy(nc.signature, rc.Signature)
 	copy(nc.signature, rc.Signature)
-	copy(nc.details.Groups, rc.Details.Groups)
-	nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer)
+	copy(nc.details.groups, rc.Details.Groups)
+	nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
 
 
-	if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey {
-		return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey))
+	if len(publicKey) > 0 {
+		nc.details.publicKey = publicKey
 	}
 	}
-	copy(nc.details.PublicKey, rc.Details.PublicKey)
+
+	copy(nc.details.publicKey, rc.Details.PublicKey)
 
 
 	var ip netip.Addr
 	var ip netip.Addr
 	for i, rawIp := range rc.Details.Ips {
 	for i, rawIp := range rc.Details.Ips {
@@ -399,7 +379,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
 			ip = int2addr(rawIp)
 			ip = int2addr(rawIp)
 		} else {
 		} else {
 			ones, _ := net.IPMask(int2ip(rawIp)).Size()
 			ones, _ := net.IPMask(int2ip(rawIp)).Size()
-			nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones)
+			nc.details.networks[i/2] = netip.PrefixFrom(ip, ones)
 		}
 		}
 	}
 	}
 
 
@@ -408,67 +388,13 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err
 			ip = int2addr(rawIp)
 			ip = int2addr(rawIp)
 		} else {
 		} else {
 			ones, _ := net.IPMask(int2ip(rawIp)).Size()
 			ones, _ := net.IPMask(int2ip(rawIp)).Size()
-			nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones)
+			nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones)
 		}
 		}
 	}
 	}
 
 
-	return &nc, nil
-}
-
-func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) {
-	c := &certificateV1{
-		details: detailsV1{
-			Name:      t.Name,
-			Ips:       t.Networks,
-			Subnets:   t.UnsafeNetworks,
-			Groups:    t.Groups,
-			NotBefore: t.NotBefore,
-			NotAfter:  t.NotAfter,
-			PublicKey: t.PublicKey,
-			IsCA:      t.IsCA,
-			Curve:     t.Curve,
-			Issuer:    t.issuer,
-		},
-	}
-	b, err := proto.Marshal(c.getRawDetails())
-	if err != nil {
-		return nil, err
-	}
-
-	var sig []byte
+	//do not sort the subnets field for V1 certs
 
 
-	switch curve {
-	case Curve_CURVE25519:
-		signer := ed25519.PrivateKey(key)
-		sig = ed25519.Sign(signer, b)
-	case Curve_P256:
-		if client != nil {
-			sig, err = client.SignASN1(b)
-		} else {
-			signer := &ecdsa.PrivateKey{
-				PublicKey: ecdsa.PublicKey{
-					Curve: elliptic.P256(),
-				},
-				// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
-				D: new(big.Int).SetBytes(key),
-			}
-			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
-			signer.X, signer.Y = signer.Curve.ScalarBaseMult(key)
-
-			// We need to hash first for ECDSA
-			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
-			hashed := sha256.Sum256(b)
-			sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:])
-			if err != nil {
-				return nil, err
-			}
-		}
-	default:
-		return nil, fmt.Errorf("invalid curve: %s", c.details.Curve)
-	}
-
-	c.signature = sig
-	return c, nil
+	return &nc, nil
 }
 }
 
 
 func ip2int(ip []byte) uint32 {
 func ip2int(ip []byte) uint32 {

+ 37 - 0
cert/cert_v2.asn1

@@ -0,0 +1,37 @@
+Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN
+
+Name ::= UTF8String (SIZE (1..253))
+Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum
+Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length
+Curve ::= ENUMERATED {
+    curve25519 (0),
+    p256 (1)
+}
+
+-- The maximum size of a certificate must not exceed 65536 bytes
+Certificate ::= SEQUENCE {
+    details OCTET STRING,
+    curve Curve DEFAULT curve25519,
+    publicKey OCTET STRING,
+    -- signature(details + curve + publicKey) using the appropriate method for curve
+    signature OCTET STRING
+}
+
+Details ::= SEQUENCE {
+    name Name,
+
+    -- At least 1 ipv4 or ipv6 address must be present if isCA is false
+    networks SEQUENCE OF Network OPTIONAL,
+    unsafeNetworks SEQUENCE OF Network OPTIONAL,
+    groups SEQUENCE OF Name OPTIONAL,
+    isCA BOOLEAN DEFAULT false,
+    notBefore Time,
+    notAfter Time,
+
+    -- issuer is only required if isCA is false, if isCA is true then it must not be present
+    issuer OCTET STRING OPTIONAL,
+    ...
+    -- New fields can be added below here
+}
+
+END

+ 621 - 0
cert/cert_v2.go

@@ -0,0 +1,621 @@
+package cert
+
+import (
+	"bytes"
+	"crypto/ecdh"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"net/netip"
+	"slices"
+	"time"
+
+	"golang.org/x/crypto/cryptobyte"
+	"golang.org/x/crypto/cryptobyte/asn1"
+	"golang.org/x/crypto/curve25519"
+)
+
+//TODO: should we avoid hex encoding shit on output? Just let it be base64?
+
+const (
+	classConstructed     = 0x20
+	classContextSpecific = 0x80
+
+	TagCertDetails   = 0 | classConstructed | classContextSpecific
+	TagCertCurve     = 1 | classContextSpecific
+	TagCertPublicKey = 2 | classContextSpecific
+	TagCertSignature = 3 | classContextSpecific
+
+	TagDetailsName           = 0 | classContextSpecific
+	TagDetailsNetworks       = 1 | classConstructed | classContextSpecific
+	TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific
+	TagDetailsGroups         = 3 | classConstructed | classContextSpecific
+	TagDetailsIsCA           = 4 | classContextSpecific
+	TagDetailsNotBefore      = 5 | classContextSpecific
+	TagDetailsNotAfter       = 6 | classContextSpecific
+	TagDetailsIssuer         = 7 | classContextSpecific
+)
+
+const (
+	// MaxCertificateSize is the maximum length a valid certificate can be
+	MaxCertificateSize = 65536
+
+	// MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems
+	MaxNameLength = 253
+
+	// MaxNetworkLength is the maximum length a network value can be.
+	// 16 bytes for an ipv6 address + 1 byte for the prefix length
+	MaxNetworkLength = 17
+)
+
+type certificateV2 struct {
+	details detailsV2
+
+	// RawDetails contains the entire asn.1 DER encoded Details struct
+	// This is to benefit forwards compatibility in signature checking.
+	// signature(RawDetails + Curve + PublicKey) == Signature
+	rawDetails []byte
+	curve      Curve
+	publicKey  []byte
+	signature  []byte
+}
+
+type detailsV2 struct {
+	name           string
+	networks       []netip.Prefix
+	unsafeNetworks []netip.Prefix
+	groups         []string
+	isCA           bool
+	notBefore      time.Time
+	notAfter       time.Time
+	issuer         string
+}
+
+func (c *certificateV2) Version() Version {
+	return Version2
+}
+
+func (c *certificateV2) Curve() Curve {
+	return c.curve
+}
+
+func (c *certificateV2) Groups() []string {
+	return c.details.groups
+}
+
+func (c *certificateV2) IsCA() bool {
+	return c.details.isCA
+}
+
+func (c *certificateV2) Issuer() string {
+	return c.details.issuer
+}
+
+func (c *certificateV2) Name() string {
+	return c.details.name
+}
+
+func (c *certificateV2) Networks() []netip.Prefix {
+	return c.details.networks
+}
+
+func (c *certificateV2) NotAfter() time.Time {
+	return c.details.notAfter
+}
+
+func (c *certificateV2) NotBefore() time.Time {
+	return c.details.notBefore
+}
+
+func (c *certificateV2) PublicKey() []byte {
+	return c.publicKey
+}
+
+func (c *certificateV2) Signature() []byte {
+	return c.signature
+}
+
+func (c *certificateV2) UnsafeNetworks() []netip.Prefix {
+	return c.details.unsafeNetworks
+}
+
+func (c *certificateV2) Fingerprint() (string, error) {
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	//TODO: double check this, panic on empty raw details
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature)
+	sum := sha256.Sum256(b)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func (c *certificateV2) CheckSignature(key []byte) bool {
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	//TODO: double check this, panic on empty raw details
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+
+	switch c.curve {
+	case Curve_CURVE25519:
+		return ed25519.Verify(key, b, c.signature)
+	case Curve_P256:
+		//TODO: NewPublicKey
+		x, y := elliptic.Unmarshal(elliptic.P256(), key)
+		pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}
+		hashed := sha256.Sum256(b)
+		return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature)
+	default:
+		return false
+	}
+}
+
+func (c *certificateV2) Expired(t time.Time) bool {
+	return c.details.notBefore.After(t) || c.details.notAfter.Before(t)
+}
+
+func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error {
+	if curve != c.curve {
+		return fmt.Errorf("curve in cert and private key supplied don't match")
+	}
+	if c.details.isCA {
+		switch curve {
+		case Curve_CURVE25519:
+			// the call to PublicKey below will panic slice bounds out of range otherwise
+			if len(key) != ed25519.PrivateKeySize {
+				return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key")
+			}
+
+			if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		case Curve_P256:
+			privkey, err := ecdh.P256().NewPrivateKey(key)
+			if err != nil {
+				return fmt.Errorf("cannot parse private key as P256")
+			}
+			pub := privkey.PublicKey().Bytes()
+			if !bytes.Equal(pub, c.publicKey) {
+				return fmt.Errorf("public key in cert and private key supplied don't match")
+			}
+		default:
+			return fmt.Errorf("invalid curve: %s", curve)
+		}
+		return nil
+	}
+
+	var pub []byte
+	switch curve {
+	case Curve_CURVE25519:
+		var err error
+		pub, err = curve25519.X25519(key, curve25519.Basepoint)
+		if err != nil {
+			return err
+		}
+	case Curve_P256:
+		privkey, err := ecdh.P256().NewPrivateKey(key)
+		if err != nil {
+			return err
+		}
+		pub = privkey.PublicKey().Bytes()
+	default:
+		return fmt.Errorf("invalid curve: %s", curve)
+	}
+	if !bytes.Equal(pub, c.publicKey) {
+		return fmt.Errorf("public key in cert and private key supplied don't match")
+	}
+
+	return nil
+}
+
+func (c *certificateV2) String() string {
+	b, err := json.MarshalIndent(c.marshalJSON(), "", "\t")
+	if err != nil {
+		return "<error marshalling certificate>"
+	}
+	return string(b)
+}
+
+func (c *certificateV2) MarshalForHandshakes() ([]byte, error) {
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		//TODO: panic on nil rawDetails
+		b.AddBytes(c.rawDetails)
+
+		// Skipping the curve and public key since those come across in a different part of the handshake
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) Marshal() ([]byte, error) {
+	var b cryptobyte.Builder
+	// Outermost certificate
+	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
+
+		// Add the cert details which is already marshalled
+		b.AddBytes(c.rawDetails)
+
+		// Add the curve only if its not the default value
+		if c.curve != Curve_CURVE25519 {
+			b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) {
+				b.AddBytes([]byte{byte(c.curve)})
+			})
+		}
+
+		// Add the public key if it is not empty
+		if c.publicKey != nil {
+			b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) {
+				b.AddBytes(c.publicKey)
+			})
+		}
+
+		// Add the signature
+		b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) {
+			b.AddBytes(c.signature)
+		})
+	})
+
+	return b.Bytes()
+}
+
+func (c *certificateV2) MarshalPEM() ([]byte, error) {
+	b, err := c.Marshal()
+	if err != nil {
+		return nil, err
+	}
+	return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil
+}
+
+func (c *certificateV2) MarshalJSON() ([]byte, error) {
+	return json.Marshal(c.marshalJSON())
+}
+
+func (c *certificateV2) marshalJSON() m {
+	fp, _ := c.Fingerprint()
+	return m{
+		"details": m{
+			"name":           c.details.name,
+			"networks":       c.details.networks,
+			"unsafeNetworks": c.details.unsafeNetworks,
+			"groups":         c.details.groups,
+			"notBefore":      c.details.notBefore,
+			"notAfter":       c.details.notAfter,
+			"isCa":           c.details.isCA,
+			"issuer":         c.details.issuer,
+		},
+		"version":     Version2,
+		"publicKey":   fmt.Sprintf("%x", c.publicKey),
+		"curve":       c.curve.String(),
+		"fingerprint": fp,
+		"signature":   fmt.Sprintf("%x", c.Signature()),
+	}
+}
+
+func (c *certificateV2) Copy() Certificate {
+	nc := &certificateV2{
+		details: detailsV2{
+			name:           c.details.name,
+			groups:         make([]string, len(c.details.groups)),
+			networks:       make([]netip.Prefix, len(c.details.networks)),
+			unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)),
+			notBefore:      c.details.notBefore,
+			notAfter:       c.details.notAfter,
+			isCA:           c.details.isCA,
+			issuer:         c.details.issuer,
+		},
+		curve:     c.curve,
+		publicKey: make([]byte, len(c.publicKey)),
+		signature: make([]byte, len(c.signature)),
+	}
+
+	copy(nc.signature, c.signature)
+	copy(nc.details.groups, c.details.groups)
+	copy(nc.publicKey, c.publicKey)
+	copy(nc.details.networks, c.details.networks)
+	copy(nc.details.unsafeNetworks, c.details.unsafeNetworks)
+
+	return nc
+}
+
+func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
+	c.details = detailsV2{
+		name:           t.Name,
+		networks:       t.Networks,
+		unsafeNetworks: t.UnsafeNetworks,
+		groups:         t.Groups,
+		isCA:           t.IsCA,
+		notBefore:      t.NotBefore,
+		notAfter:       t.NotAfter,
+		issuer:         t.issuer,
+	}
+	c.curve = t.Curve
+	c.publicKey = t.PublicKey
+	return nil
+}
+
+func (c *certificateV2) marshalForSigning() ([]byte, error) {
+	d, err := c.details.Marshal()
+	if err != nil {
+		//TODO: annotate?
+		return nil, err
+	}
+	c.rawDetails = d
+
+	b := make([]byte, len(c.rawDetails)+1+len(c.publicKey))
+	//TODO: double check this
+	copy(b, c.rawDetails)
+	b[len(c.rawDetails)] = byte(c.curve)
+	copy(b[len(c.rawDetails)+1:], c.publicKey)
+	return b, nil
+}
+
+func (c *certificateV2) setSignature(b []byte) error {
+	c.signature = b
+	return nil
+}
+
+func (d *detailsV2) Marshal() ([]byte, error) {
+	var b cryptobyte.Builder
+	var err error
+
+	// Details are a structure
+	b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) {
+
+		// Add the name
+		b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) {
+			b.AddBytes([]byte(d.name))
+		})
+
+		// Add the networks if any exist
+		if len(d.networks) > 0 {
+			b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.networks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add the unsafe networks if any exist
+		if len(d.unsafeNetworks) > 0 {
+			b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) {
+				for _, n := range d.unsafeNetworks {
+					sb, innerErr := n.MarshalBinary()
+					if innerErr != nil {
+						// MarshalBinary never returns an error
+						err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr)
+						return
+					}
+					b.AddASN1OctetString(sb)
+				}
+			})
+		}
+
+		// Add groups if any exist
+		if len(d.groups) > 0 {
+			b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) {
+				for _, group := range d.groups {
+					b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) {
+						b.AddBytes([]byte(group))
+					})
+				}
+			})
+		}
+
+		// Add IsCA only if true
+		if d.isCA {
+			b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) {
+				b.AddUint8(0xff)
+			})
+		}
+
+		// Add not before
+		b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore)
+
+		// Add not after
+		b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter)
+
+		// Add the issuer if present
+		if d.issuer != "" {
+			issuerBytes, innerErr := hex.DecodeString(d.issuer)
+			if innerErr != nil {
+				err = fmt.Errorf("failed to decode issuer: %w", innerErr)
+				return
+			}
+			b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) {
+				b.AddBytes(issuerBytes)
+			})
+		}
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return b.Bytes()
+}
+
+func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
+	l := len(b)
+	if l == 0 || l > MaxCertificateSize {
+		return nil, ErrBadFormat
+	}
+
+	input := cryptobyte.String(b)
+	// Open the envelope
+	if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Grab the cert details, we need to preserve the tag and length
+	var rawDetails cryptobyte.String
+	if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	//Maybe grab the curve
+	var rawCurve byte
+	if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
+		return nil, ErrBadFormat
+	}
+	curve = Curve(rawCurve)
+
+	// Maybe grab the public key
+	var rawPublicKey cryptobyte.String
+	if len(publicKey) > 0 {
+		rawPublicKey = publicKey
+	} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
+		return nil, ErrBadFormat
+	}
+
+	//TODO: Assert public key length
+
+	// Grab the signature
+	var rawSignature cryptobyte.String
+	if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() {
+		return nil, ErrBadFormat
+	}
+
+	// Finally unmarshal the details
+	details, err := unmarshalDetails(rawDetails)
+	if err != nil {
+		return nil, err
+	}
+
+	return &certificateV2{
+		details:    details,
+		rawDetails: rawDetails,
+		curve:      curve,
+		publicKey:  rawPublicKey,
+		signature:  rawSignature,
+	}, nil
+}
+
+func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
+	// Open the envelope
+	if !b.ReadASN1(&b, TagCertDetails) || b.Empty() {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the name
+	var name cryptobyte.String
+	if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read the network addresses
+	var subString cryptobyte.String
+	var found bool
+
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var networks []netip.Prefix
+	var val cryptobyte.String
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			networks = append(networks, n)
+		}
+	}
+
+	// Read out any unsafe networks
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var unsafeNetworks []netip.Prefix
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength {
+				return detailsV2{}, ErrBadFormat
+			}
+
+			var n netip.Prefix
+			if err := n.UnmarshalBinary(val); err != nil {
+				return detailsV2{}, ErrBadFormat
+			}
+			unsafeNetworks = append(unsafeNetworks, n)
+		}
+	}
+
+	// Read out any groups
+	if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var groups []string
+	if found {
+		for !subString.Empty() {
+			if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() {
+				return detailsV2{}, ErrBadFormat
+			}
+			groups = append(groups, string(val))
+		}
+	}
+
+	// Read out IsCA
+	var isCa bool
+	if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read not before and not after
+	var notBefore int64
+	if !b.ReadASN1Int64WithTag(&notBefore, TagDetailsNotBefore) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	var notAfter int64
+	if !b.ReadASN1Int64WithTag(&notAfter, TagDetailsNotAfter) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	// Read issuer
+	var issuer cryptobyte.String
+	if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) {
+		return detailsV2{}, ErrBadFormat
+	}
+
+	slices.SortFunc(networks, comparePrefix)
+	slices.SortFunc(unsafeNetworks, comparePrefix)
+
+	return detailsV2{
+		name:           string(name),
+		networks:       networks,
+		unsafeNetworks: unsafeNetworks,
+		groups:         groups,
+		isCA:           isCa,
+		notBefore:      time.Unix(notBefore, 0),
+		notAfter:       time.Unix(notAfter, 0),
+		issuer:         hex.EncodeToString(issuer),
+	}, nil
+}

+ 3 - 0
cert/errors.go

@@ -24,4 +24,7 @@ var (
 	ErrInvalidPEMX25519PrivateKeyBanner  = errors.New("bytes did not contain a proper X25519 private key banner")
 	ErrInvalidPEMX25519PrivateKeyBanner  = errors.New("bytes did not contain a proper X25519 private key banner")
 	ErrInvalidPEMEd25519PublicKeyBanner  = errors.New("bytes did not contain a proper Ed25519 public key banner")
 	ErrInvalidPEMEd25519PublicKeyBanner  = errors.New("bytes did not contain a proper Ed25519 public key banner")
 	ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
 	ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner")
+
+	ErrNoPeerStaticKey = errors.New("no peer static key was present")
+	ErrNoPayload       = errors.New("provided payload was empty")
 )
 )

+ 12 - 7
cert/pem.go

@@ -30,19 +30,24 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
 		return nil, r, ErrInvalidPEMBlock
 		return nil, r, ErrInvalidPEMBlock
 	}
 	}
 
 
+	var c Certificate
+	var err error
+
 	switch p.Type {
 	switch p.Type {
 	case CertificateBanner:
 	case CertificateBanner:
-		c, err := unmarshalCertificateV1(p.Bytes, true)
-		if err != nil {
-			return nil, nil, err
-		}
-		return c, r, nil
+		c, err = unmarshalCertificateV1(p.Bytes, nil)
 	case CertificateV2Banner:
 	case CertificateV2Banner:
-		//TODO
-		panic("TODO")
+		c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
 	default:
 	default:
 		return nil, r, ErrInvalidPEMCertificateBanner
 		return nil, r, ErrInvalidPEMCertificateBanner
 	}
 	}
+
+	if err != nil {
+		return nil, r, err
+	}
+
+	return c, r, nil
+
 }
 }
 
 
 func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {
 func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte {

+ 97 - 12
cert/sign.go

@@ -1,11 +1,16 @@
 package cert
 package cert
 
 
 import (
 import (
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/sha256"
 	"fmt"
 	"fmt"
+	"math/big"
 	"net/netip"
 	"net/netip"
+	"slices"
 	"time"
 	"time"
-
-	"github.com/slackhq/nebula/pkclient"
 )
 )
 
 
 // TBSCertificate represents a certificate intended to be signed.
 // TBSCertificate represents a certificate intended to be signed.
@@ -24,27 +29,62 @@ type TBSCertificate struct {
 	issuer         string
 	issuer         string
 }
 }
 
 
+type beingSignedCertificate interface {
+	// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
+	fromTBSCertificate(*TBSCertificate) error
+
+	// marshalForSigning returns the bytes that should be signed
+	marshalForSigning() ([]byte, error)
+
+	// setSignature sets the signature for the certificate that has just been signed
+	setSignature([]byte) error
+}
+
+type SignerLambda func(certBytes []byte) ([]byte, error)
+
 // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
 // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those
 // details do not violate constraints of the signing certificate.
 // details do not violate constraints of the signing certificate.
 // If the TBSCertificate is a CA then signer must be nil.
 // If the TBSCertificate is a CA then signer must be nil.
 func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
 func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) {
-	return t.sign(signer, curve, key, nil)
-}
-
-func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) {
-	if curve != Curve_P256 {
-		return nil, fmt.Errorf("only P256 is supported by PKCS#11")
+	switch t.Curve {
+	case Curve_CURVE25519:
+		pk := ed25519.PrivateKey(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			sig := ed25519.Sign(pk, certBytes)
+			return sig, nil
+		}
+		return t.SignWith(signer, curve, sp)
+	case Curve_P256:
+		pk := &ecdsa.PrivateKey{
+			PublicKey: ecdsa.PublicKey{
+				Curve: elliptic.P256(),
+			},
+			// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95
+			D: new(big.Int).SetBytes(key),
+		}
+		// ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119
+		pk.X, pk.Y = pk.Curve.ScalarBaseMult(key)
+		sp := func(certBytes []byte) ([]byte, error) {
+			// We need to hash first for ECDSA
+			// - https://pkg.go.dev/crypto/ecdsa#SignASN1
+			hashed := sha256.Sum256(certBytes)
+			return ecdsa.SignASN1(rand.Reader, pk, hashed[:])
+		}
+		return t.SignWith(signer, curve, sp)
+	default:
+		return nil, fmt.Errorf("invalid curve: %s", t.Curve)
 	}
 	}
-
-	return t.sign(signer, curve, nil, client)
 }
 }
 
 
-func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) {
+// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
+// You should only use SignWith if you do not have direct access to your private key.
+func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
 	if curve != t.Curve {
 	if curve != t.Curve {
 		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
 		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
 	}
 	}
 
 
 	//TODO: make sure we have all minimum properties to sign, like a public key
 	//TODO: make sure we have all minimum properties to sign, like a public key
+	//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs
 
 
 	if signer != nil {
 	if signer != nil {
 		if t.IsCA {
 		if t.IsCA {
@@ -67,10 +107,55 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien
 		}
 		}
 	}
 	}
 
 
+	slices.SortFunc(t.Networks, comparePrefix)
+	slices.SortFunc(t.UnsafeNetworks, comparePrefix)
+
+	var c beingSignedCertificate
 	switch t.Version {
 	switch t.Version {
 	case Version1:
 	case Version1:
-		return signV1(t, curve, key, client)
+		c = &certificateV1{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
+	case Version2:
+		c = &certificateV2{}
+		err := c.fromTBSCertificate(t)
+		if err != nil {
+			return nil, err
+		}
 	default:
 	default:
 		return nil, fmt.Errorf("unknown cert version %d", t.Version)
 		return nil, fmt.Errorf("unknown cert version %d", t.Version)
 	}
 	}
+
+	certBytes, err := c.marshalForSigning()
+	if err != nil {
+		return nil, err
+	}
+
+	sig, err := sp(certBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	//TODO: check if we have sig bytes?
+	err = c.setSignature(sig)
+	if err != nil {
+		return nil, err
+	}
+
+	sc, ok := c.(Certificate)
+	if !ok {
+		return nil, fmt.Errorf("invalid certificate")
+	}
+
+	return sc, nil
+}
+
+func comparePrefix(a, b netip.Prefix) int {
+	addr := a.Addr().Compare(b.Addr())
+	if addr == 0 {
+		return a.Bits() - b.Bits()
+	}
+	return addr
 }
 }

+ 46 - 22
cmd/nebula-cert/ca.go

@@ -27,34 +27,43 @@ type caFlags struct {
 	outCertPath      *string
 	outCertPath      *string
 	outQRPath        *string
 	outQRPath        *string
 	groups           *string
 	groups           *string
-	ips              *string
-	subnets          *string
+	networks         *string
+	unsafeNetworks   *string
 	argonMemory      *uint
 	argonMemory      *uint
 	argonIterations  *uint
 	argonIterations  *uint
 	argonParallelism *uint
 	argonParallelism *uint
 	encryption       *bool
 	encryption       *bool
+	version          *uint
 
 
 	curve  *string
 	curve  *string
 	p11url *string
 	p11url *string
+
+	// Deprecated options
+	ips     *string
+	subnets *string
 }
 }
 
 
 func newCaFlags() *caFlags {
 func newCaFlags() *caFlags {
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)}
 	cf.set.Usage = func() {}
 	cf.set.Usage = func() {}
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
 	cf.name = cf.set.String("name", "", "Required: name of the certificate authority")
+	cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
 	cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use")
-	cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses")
-	cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets")
+	cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks")
+	cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
 	cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)")
 	cf.p11url = p11Flag(cf.set)
 	cf.p11url = p11Flag(cf.set)
+
+	cf.ips = cf.set.String("ips", "", "Deprecated, see -networks")
+	cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &cf
 	return &cf
 }
 }
 
 
@@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 		}
 		}
 	}
 	}
 
 
-	var ips []netip.Prefix
-	if *cf.ips != "" {
-		for _, rs := range strings.Split(*cf.ips, ",") {
+	version := cert.Version(*cf.version)
+	if version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
+	var networks []netip.Prefix
+	if *cf.networks == "" && *cf.ips != "" {
+		// Pull up deprecated -ips flag if needed
+		*cf.networks = *cf.ips
+	}
+
+	if *cf.networks != "" {
+		for _, rs := range strings.Split(*cf.networks, ",") {
 			rs := strings.Trim(rs, " ")
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 			if rs != "" {
 				n, err := netip.ParsePrefix(rs)
 				n, err := netip.ParsePrefix(rs)
 				if err != nil {
 				if err != nil {
-					return newHelpErrorf("invalid ip definition: %s", err)
+					return newHelpErrorf("invalid -networks definition: %s", rs)
 				}
 				}
-				if !n.Addr().Is4() {
-					return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
 				}
-				ips = append(ips, n)
+				networks = append(networks, n)
 			}
 			}
 		}
 		}
 	}
 	}
 
 
-	var subnets []netip.Prefix
-	if *cf.subnets != "" {
-		for _, rs := range strings.Split(*cf.subnets, ",") {
+	var unsafeNetworks []netip.Prefix
+	if *cf.unsafeNetworks == "" && *cf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*cf.unsafeNetworks = *cf.subnets
+	}
+
+	if *cf.unsafeNetworks != "" {
+		for _, rs := range strings.Split(*cf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 			if rs != "" {
 				n, err := netip.ParsePrefix(rs)
 				n, err := netip.ParsePrefix(rs)
 				if err != nil {
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", err)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
 				}
-				if !n.Addr().Is4() {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+				if version == cert.Version1 && !n.Addr().Is4() {
+					return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs)
 				}
 				}
-				subnets = append(subnets, n)
+				unsafeNetworks = append(unsafeNetworks, n)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 	}
 	}
 
 
 	t := &cert.TBSCertificate{
 	t := &cert.TBSCertificate{
-		Version:        cert.Version1,
+		Version:        version,
 		Name:           *cf.name,
 		Name:           *cf.name,
 		Groups:         groups,
 		Groups:         groups,
-		Networks:       ips,
-		UnsafeNetworks: subnets,
+		Networks:       networks,
+		UnsafeNetworks: unsafeNetworks,
 		NotBefore:      time.Now(),
 		NotBefore:      time.Now(),
 		NotAfter:       time.Now().Add(*cf.duration),
 		NotAfter:       time.Now().Add(*cf.duration),
 		PublicKey:      pub,
 		PublicKey:      pub,
@@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error
 	var b []byte
 	var b []byte
 
 
 	if isP11 {
 	if isP11 {
-		c, err = t.SignPkcs11(nil, curve, p11Client)
+		c, err = t.SignWith(nil, curve, p11Client.SignASN1)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 			return fmt.Errorf("error while signing with PKCS#11: %w", err)
 		}
 		}

+ 20 - 14
cmd/nebula-cert/ca_test.go

@@ -43,9 +43,11 @@ func Test_caHelp(t *testing.T) {
 			"  -groups string\n"+
 			"  -groups string\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"    \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+
 			"  -ips string\n"+
 			"  -ips string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+
+			"    	Deprecated, see -networks\n"+
 			"  -name string\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the certificate authority\n"+
 			"    \tRequired: name of the certificate authority\n"+
+			"  -networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+
 			"  -out-crt string\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"    \tOptional: path to write the certificate to (default \"ca.crt\")\n"+
 			"  -out-key string\n"+
 			"  -out-key string\n"+
@@ -54,7 +56,11 @@ func Test_caHelp(t *testing.T) {
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use (default 2)\n",
 		ob.String(),
 		ob.String(),
 	)
 	)
 }
 }
@@ -83,25 +89,25 @@ func Test_ca(t *testing.T) {
 
 
 	// required args
 	// required args
 	assertHelpError(t, ca(
 	assertHelpError(t, ca(
-		[]string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
+		[]string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw,
 	), "-name is required")
 	), "-name is required")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// ipv4 only ips
 	// ipv4 only ips
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// ipv4 only subnets
 	// ipv4 only subnets
-	assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100")
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
 
 
 	// failed key write
 	// failed key write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
+	args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -114,7 +120,7 @@ func Test_ca(t *testing.T) {
 	// failed cert write
 	// failed cert write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -128,7 +134,7 @@ func Test_ca(t *testing.T) {
 	// test proper cert with removed empty groups and subnets
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -161,7 +167,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, testpw))
 	assert.Nil(t, ca(args, ob, eb, testpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -189,7 +195,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Error(t, ca(args, ob, eb, errpw))
 	assert.Error(t, ca(args, ob, eb, errpw))
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, pwPromptOb, ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -199,7 +205,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -209,13 +215,13 @@ func Test_ca(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.Nil(t, ca(args, ob, eb, nopw))
 	assert.Nil(t, ca(args, ob, eb, nopw))
 
 
 	// test that we won't overwrite existing certificate file
 	// test that we won't overwrite existing certificate file
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -224,7 +230,7 @@ func Test_ca(t *testing.T) {
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
+	args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,,   2    ,        ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())

+ 11 - 6
cmd/nebula-cert/print.go

@@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 	var qrBytes []byte
 	var qrBytes []byte
 	part := 0
 	part := 0
 
 
+	var jsonCerts []cert.Certificate
+
 	for {
 	for {
 		c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
 		c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert)
 		if err != nil {
 		if err != nil {
@@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		}
 		}
 
 
 		if *pf.json {
 		if *pf.json {
-			b, _ := json.Marshal(c)
-			out.Write(b)
-			out.Write([]byte("\n"))
-
+			jsonCerts = append(jsonCerts, c)
 		} else {
 		} else {
-			out.Write([]byte(c.String()))
-			out.Write([]byte("\n"))
+			_, _ = out.Write([]byte(c.String()))
+			_, _ = out.Write([]byte("\n"))
 		}
 		}
 
 
 		if *pf.outQRPath != "" {
 		if *pf.outQRPath != "" {
@@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error {
 		part++
 		part++
 	}
 	}
 
 
+	if *pf.json {
+		b, _ := json.Marshal(jsonCerts)
+		_, _ = out.Write(b)
+		_, _ = out.Write([]byte("\n"))
+	}
+
 	if *pf.outQRPath != "" {
 	if *pf.outQRPath != "" {
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5)
 		if err != nil {
 		if err != nil {

+ 61 - 2
cmd/nebula-cert/print_test.go

@@ -87,7 +87,65 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(
 	assert.Equal(
 		t,
 		t,
-		"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
+		`{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+{
+	"details": {
+		"curve": "CURVE25519",
+		"groups": [
+			"hi"
+		],
+		"isCa": false,
+		"issuer": "`+c.Issuer()+`",
+		"name": "test",
+		"networks": [],
+		"notAfter": "0001-01-01T00:00:00Z",
+		"notBefore": "0001-01-01T00:00:00Z",
+		"publicKey": "`+pk+`",
+		"unsafeNetworks": []
+	},
+	"fingerprint": "`+fp+`",
+	"signature": "`+sig+`",
+	"version": 1
+}
+`,
 		ob.String(),
 		ob.String(),
 	)
 	)
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())
@@ -108,7 +166,8 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(
 	assert.Equal(
 		t,
 		t,
-		"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n",
+		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
+`,
 		ob.String(),
 		ob.String(),
 	)
 	)
 	assert.Equal(t, "", eb.String())
 	assert.Equal(t, "", eb.String())

+ 168 - 66
cmd/nebula-cert/sign.go

@@ -3,6 +3,7 @@ package main
 import (
 import (
 	"crypto/ecdh"
 	"crypto/ecdh"
 	"crypto/rand"
 	"crypto/rand"
+	"errors"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
@@ -18,36 +19,46 @@ import (
 )
 )
 
 
 type signFlags struct {
 type signFlags struct {
-	set         *flag.FlagSet
-	caKeyPath   *string
-	caCertPath  *string
-	name        *string
-	ip          *string
-	duration    *time.Duration
-	inPubPath   *string
-	outKeyPath  *string
-	outCertPath *string
-	outQRPath   *string
-	groups      *string
-	subnets     *string
-	p11url      *string
+	set            *flag.FlagSet
+	version        *uint
+	caKeyPath      *string
+	caCertPath     *string
+	name           *string
+	networks       *string
+	unsafeNetworks *string
+	duration       *time.Duration
+	inPubPath      *string
+	outKeyPath     *string
+	outCertPath    *string
+	outQRPath      *string
+	groups         *string
+
+	p11url *string
+
+	// Deprecated options
+	ip      *string
+	subnets *string
 }
 }
 
 
 func newSignFlags() *signFlags {
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
 	sf.set.Usage = func() {}
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
-	sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert")
+	sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert")
+	sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
 	sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups")
-	sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for")
 	sf.p11url = p11Flag(sf.set)
 	sf.p11url = p11Flag(sf.set)
+
+	sf.ip = sf.set.String("ip", "", "Deprecated, see -networks")
+	sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks")
 	return &sf
 	return &sf
 }
 }
 
 
@@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	if err := mustFlagString("name", sf.name); err != nil {
 	if err := mustFlagString("name", sf.name); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := mustFlagString("ip", sf.ip); err != nil {
-		return err
-	}
 	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 	if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" {
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 		return newHelpErrorf("cannot set both -in-pub and -out-key")
 	}
 	}
 
 
+	var v4Networks []netip.Prefix
+	var v6Networks []netip.Prefix
+	if *sf.networks == "" && *sf.ip != "" {
+		// Pull up deprecated -ip flag if needed
+		*sf.networks = *sf.ip
+	}
+
+	if len(*sf.networks) == 0 {
+		return newHelpErrorf("-networks is required")
+	}
+
+	version := cert.Version(*sf.version)
+	if version != 0 && version != cert.Version1 && version != cert.Version2 {
+		return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2)
+	}
+
 	var curve cert.Curve
 	var curve cert.Curve
 	var caKey []byte
 	var caKey []byte
 
 
@@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 
 
 		// naively attempt to decode the private key as though it is not encrypted
 		// naively attempt to decode the private key as though it is not encrypted
 		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
 		caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey)
-		if err == cert.ErrPrivateKeyEncrypted {
+		if errors.Is(err, cert.ErrPrivateKeyEncrypted) {
 			// ask for a passphrase until we get one
 			// ask for a passphrase until we get one
 			var passphrase []byte
 			var passphrase []byte
 			for i := 0; i < 5; i++ {
 			for i := 0; i < 5; i++ {
 				out.Write([]byte("Enter passphrase: "))
 				out.Write([]byte("Enter passphrase: "))
 				passphrase, err = pr.ReadPassword()
 				passphrase, err = pr.ReadPassword()
 
 
-				if err == ErrNoTerminal {
+				if errors.Is(err, ErrNoTerminal) {
 					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
 					return fmt.Errorf("ca-key is encrypted and must be decrypted interactively")
 				} else if err != nil {
 				} else if err != nil {
 					return fmt.Errorf("error reading password: %s", err)
 					return fmt.Errorf("error reading password: %s", err)
@@ -146,37 +170,57 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
 		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
 	}
 	}
 
 
-	network, err := netip.ParsePrefix(*sf.ip)
-	if err != nil {
-		return newHelpErrorf("invalid ip definition: %s", *sf.ip)
-	}
-	if !network.Addr().Is4() {
-		return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip)
-	}
+	if *sf.networks != "" {
+		for _, rs := range strings.Split(*sf.networks, ",") {
+			//TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange
+			rs := strings.Trim(rs, " ")
+			if rs != "" {
+				n, err := netip.ParsePrefix(rs)
+				if err != nil {
+					return newHelpErrorf("invalid -networks definition: %s", rs)
+				}
 
 
-	var groups []string
-	if *sf.groups != "" {
-		for _, rg := range strings.Split(*sf.groups, ",") {
-			g := strings.TrimSpace(rg)
-			if g != "" {
-				groups = append(groups, g)
+				if n.Addr().Is4() {
+					v4Networks = append(v4Networks, n)
+				} else {
+					v6Networks = append(v6Networks, n)
+				}
 			}
 			}
 		}
 		}
 	}
 	}
 
 
-	var subnets []netip.Prefix
-	if *sf.subnets != "" {
-		for _, rs := range strings.Split(*sf.subnets, ",") {
+	var v4UnsafeNetworks []netip.Prefix
+	var v6UnsafeNetworks []netip.Prefix
+	if *sf.unsafeNetworks == "" && *sf.subnets != "" {
+		// Pull up deprecated -subnets flag if needed
+		*sf.unsafeNetworks = *sf.subnets
+	}
+
+	if *sf.unsafeNetworks != "" {
+		//TODO: error on duplicates?
+		for _, rs := range strings.Split(*sf.unsafeNetworks, ",") {
 			rs := strings.Trim(rs, " ")
 			rs := strings.Trim(rs, " ")
 			if rs != "" {
 			if rs != "" {
-				s, err := netip.ParsePrefix(rs)
+				n, err := netip.ParsePrefix(rs)
 				if err != nil {
 				if err != nil {
-					return newHelpErrorf("invalid subnet definition: %s", rs)
+					return newHelpErrorf("invalid -unsafe-networks definition: %s", rs)
 				}
 				}
-				if !s.Addr().Is4() {
-					return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs)
+
+				if n.Addr().Is4() {
+					v4UnsafeNetworks = append(v4UnsafeNetworks, n)
+				} else {
+					v6UnsafeNetworks = append(v6UnsafeNetworks, n)
 				}
 				}
-				subnets = append(subnets, s)
+			}
+		}
+	}
+
+	var groups []string
+	if *sf.groups != "" {
+		for _, rg := range strings.Split(*sf.groups, ",") {
+			g := strings.TrimSpace(rg)
+			if g != "" {
+				groups = append(groups, g)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -218,19 +262,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		pub, rawPriv = newKeypair(curve)
 		pub, rawPriv = newKeypair(curve)
 	}
 	}
 
 
-	t := &cert.TBSCertificate{
-		Version:        cert.Version1,
-		Name:           *sf.name,
-		Networks:       []netip.Prefix{network},
-		Groups:         groups,
-		UnsafeNetworks: subnets,
-		NotBefore:      time.Now(),
-		NotAfter:       time.Now().Add(*sf.duration),
-		PublicKey:      pub,
-		IsCA:           false,
-		Curve:          curve,
-	}
-
 	if *sf.outKeyPath == "" {
 	if *sf.outKeyPath == "" {
 		*sf.outKeyPath = *sf.name + ".key"
 		*sf.outKeyPath = *sf.name + ".key"
 	}
 	}
@@ -243,18 +274,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 		return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath)
 	}
 	}
 
 
-	var c cert.Certificate
+	var crts []cert.Certificate
 
 
-	if p11Client == nil {
-		c, err = t.Sign(caCert, curve, caKey)
-		if err != nil {
-			return fmt.Errorf("error while signing: %w", err)
+	notBefore := time.Now()
+	notAfter := notBefore.Add(*sf.duration)
+
+	if version == 0 || version == cert.Version1 {
+		// Make sure we at least have an ip
+		if len(v4Networks) != 1 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
 		}
 		}
-	} else {
-		c, err = t.SignPkcs11(caCert, curve, p11Client)
-		if err != nil {
-			return fmt.Errorf("error while signing with PKCS#11: %w", err)
+
+		if version == cert.Version1 {
+			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
+			if len(v6Networks) > 0 {
+				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
+			}
+
+			if len(v6UnsafeNetworks) > 0 {
+				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+			}
+		}
+
+		t := &cert.TBSCertificate{
+			Version:        cert.Version1,
+			Name:           *sf.name,
+			Networks:       []netip.Prefix{v4Networks[0]},
+			Groups:         groups,
+			UnsafeNetworks: v4UnsafeNetworks,
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
 		}
 		}
+
+		crts = append(crts, nc)
+	}
+
+	if version == 0 || version == cert.Version2 {
+		t := &cert.TBSCertificate{
+			Version:        cert.Version2,
+			Name:           *sf.name,
+			Networks:       append(v4Networks, v6Networks...),
+			Groups:         groups,
+			UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...),
+			NotBefore:      notBefore,
+			NotAfter:       notAfter,
+			PublicKey:      pub,
+			IsCA:           false,
+			Curve:          curve,
+		}
+
+		var nc cert.Certificate
+		if p11Client == nil {
+			nc, err = t.Sign(caCert, curve, caKey)
+			if err != nil {
+				return fmt.Errorf("error while signing: %w", err)
+			}
+		} else {
+			nc, err = t.SignWith(caCert, curve, p11Client.SignASN1)
+			if err != nil {
+				return fmt.Errorf("error while signing with PKCS#11: %w", err)
+			}
+		}
+
+		crts = append(crts, nc)
 	}
 	}
 
 
 	if !isP11 && *sf.inPubPath == "" {
 	if !isP11 && *sf.inPubPath == "" {
@@ -268,9 +366,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		}
 		}
 	}
 	}
 
 
-	b, err := c.MarshalPEM()
-	if err != nil {
-		return fmt.Errorf("error while marshalling certificate: %s", err)
+	var b []byte
+	for _, c := range crts {
+		sb, err := c.MarshalPEM()
+		if err != nil {
+			return fmt.Errorf("error while marshalling certificate: %s", err)
+		}
+		b = append(b, sb...)
 	}
 	}
 
 
 	err = os.WriteFile(*sf.outCertPath, b, 0600)
 	err = os.WriteFile(*sf.outCertPath, b, 0600)

+ 47 - 34
cmd/nebula-cert/sign_test.go

@@ -39,9 +39,11 @@ func Test_signHelp(t *testing.T) {
 			"  -in-pub string\n"+
 			"  -in-pub string\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"    \tOptional (if out-key not set): path to read a previously generated public key\n"+
 			"  -ip string\n"+
 			"  -ip string\n"+
-			"    \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+
+			"    \tDeprecated, see -networks\n"+
 			"  -name string\n"+
 			"  -name string\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
 			"    \tRequired: name of the cert, usually a hostname\n"+
+			"  -networks string\n"+
+			"    \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+
 			"  -out-crt string\n"+
 			"  -out-crt string\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"    \tOptional: path to write the certificate to\n"+
 			"  -out-key string\n"+
 			"  -out-key string\n"+
@@ -50,7 +52,11 @@ func Test_signHelp(t *testing.T) {
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			"    \tOptional: output a qr code image (png) of the certificate\n"+
 			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			optionalPkcs11String("  -pkcs11 string\n    \tOptional: PKCS#11 URI to an existing private key\n")+
 			"  -subnets string\n"+
 			"  -subnets string\n"+
-			"    \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n",
+			"    \tDeprecated, see -unsafe-networks\n"+
+			"  -unsafe-networks string\n"+
+			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
+			"  -version uint\n"+
+			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
 		ob.String(),
 		ob.String(),
 	)
 	)
 }
 }
@@ -77,20 +83,20 @@ func Test_signCert(t *testing.T) {
 
 
 	// required args
 	// required args
 	assertHelpError(t, signCert(
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
 	), "-name is required")
 	), "-name is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	assertHelpError(t, signCert(
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
-	), "-ip is required")
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw,
+	), "-networks is required")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// cannot set -in-pub and -out-key
 	// cannot set -in-pub and -out-key
 	assertHelpError(t, signCert(
 	assertHelpError(t, signCert(
-		[]string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
+		[]string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw,
 	), "cannot set both -in-pub and -out-key")
 	), "cannot set both -in-pub and -out-key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -98,7 +104,7 @@ func Test_signCert(t *testing.T) {
 	// failed to read key
 	// failed to read key
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
 
 
 	// failed to unmarshal key
 	// failed to unmarshal key
@@ -108,7 +114,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(caKeyF.Name())
 	defer os.Remove(caKeyF.Name())
 
 
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -120,7 +126,7 @@ func Test_signCert(t *testing.T) {
 	caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
 	caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv))
 
 
 	// failed to read cert
 	// failed to read cert
-	args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -132,7 +138,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(caCrtF.Name())
 	defer os.Remove(caCrtF.Name())
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -143,7 +149,7 @@ func Test_signCert(t *testing.T) {
 	caCrtF.Write(b)
 	caCrtF.Write(b)
 
 
 	// failed to read pub
 	// failed to read pub
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -155,7 +161,7 @@ func Test_signCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	defer os.Remove(inPubF.Name())
 	defer os.Remove(inPubF.Name())
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -169,30 +175,37 @@ func Test_signCert(t *testing.T) {
 	// bad ip cidr
 	// bad ip cidr
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
+	assert.Empty(t, ob.String())
+	assert.Empty(t, eb.String())
+
+	ob.Reset()
+	eb.Reset()
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	// bad subnet cidr
 	// bad subnet cidr
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100")
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
 
 
@@ -205,7 +218,7 @@ func Test_signCert(t *testing.T) {
 
 
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -213,7 +226,7 @@ func Test_signCert(t *testing.T) {
 	// failed key write
 	// failed key write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -226,7 +239,7 @@ func Test_signCert(t *testing.T) {
 	// failed cert write
 	// failed cert write
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -240,7 +253,7 @@ func Test_signCert(t *testing.T) {
 	// test proper cert with removed empty groups and subnets
 	// test proper cert with removed empty groups and subnets
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -283,7 +296,7 @@ func Test_signCert(t *testing.T) {
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -300,7 +313,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 	eb.Reset()
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -308,14 +321,14 @@ func Test_signCert(t *testing.T) {
 	// create valid cert/key for overwrite tests
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 
 	// test that we won't overwrite existing key file
 	// test that we won't overwrite existing key file
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -323,14 +336,14 @@ func Test_signCert(t *testing.T) {
 	// create valid cert/key for overwrite tests
 	// create valid cert/key for overwrite tests
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	os.Remove(crtF.Name())
 	os.Remove(crtF.Name())
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 	assert.Nil(t, signCert(args, ob, eb, nopw))
 
 
 	// test that we won't overwrite existing certificate file
 	// test that we won't overwrite existing certificate file
 	os.Remove(keyF.Name())
 	os.Remove(keyF.Name())
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -362,7 +375,7 @@ func Test_signCert(t *testing.T) {
 	caCrtF.Write(b)
 	caCrtF.Write(b)
 
 
 	// test with the proper password
 	// test with the proper password
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Nil(t, signCert(args, ob, eb, testpw))
 	assert.Nil(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -372,7 +385,7 @@ func Test_signCert(t *testing.T) {
 	eb.Reset()
 	eb.Reset()
 
 
 	testpw.password = []byte("invalid password")
 	testpw.password = []byte("invalid password")
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, testpw))
 	assert.Error(t, signCert(args, ob, eb, testpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())
@@ -381,7 +394,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, nopw))
 	assert.Error(t, signCert(args, ob, eb, nopw))
 	// normally the user hitting enter on the prompt would add newlines between these
 	// normally the user hitting enter on the prompt would add newlines between these
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
@@ -391,7 +404,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	ob.Reset()
 	eb.Reset()
 	eb.Reset()
 
 
-	args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
+	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, ,   10.2.2.2/32   ,   ,  ,, 10.5.5.5/32", "-groups", "1,,   2    ,        ,,,3,4,5"}
 	assert.Error(t, signCert(args, ob, eb, errpw))
 	assert.Error(t, signCert(args, ob, eb, errpw))
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Equal(t, "Enter passphrase: ", ob.String())
 	assert.Empty(t, eb.String())
 	assert.Empty(t, eb.String())

+ 53 - 29
connection_manager.go

@@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 	case deleteTunnel:
 	case deleteTunnel:
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 		if n.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 		}
 
 
 	case closeTunnel:
 	case closeTunnel:
@@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 
 	for _, r := range relayFor {
 	for _, r := range relayFor {
-		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr)
 
 
 		var index uint32
 		var index uint32
 		var relayFrom netip.Addr
 		var relayFrom netip.Addr
@@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			index = existing.LocalIndex
 			switch r.Type {
 			switch r.Type {
 			case TerminalType:
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = existing.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = existing.PeerAddr
 			case ForwardingType:
 			case ForwardingType:
-				relayFrom = existing.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = existing.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 			default:
 				// should never happen
 				// should never happen
 			}
 			}
@@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			n.relayUsedLock.RUnlock()
 			n.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
 			if err != nil {
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 				continue
 			}
 			}
 			switch r.Type {
 			switch r.Type {
 			case TerminalType:
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
-				relayTo = r.PeerIp
+				relayFrom = n.intf.myVpnAddrs[0]
+				relayTo = r.PeerAddr
 			case ForwardingType:
 			case ForwardingType:
-				relayFrom = r.PeerIp
-				relayTo = newhostinfo.vpnIp
+				relayFrom = r.PeerAddr
+				relayTo = newhostinfo.vpnAddrs[0]
 			default:
 			default:
 				// should never happen
 				// should never happen
 			}
 			}
 		}
 		}
 
 
-		//TODO: IPV6-WORK
-		relayFromB := relayFrom.As4()
-		relayToB := relayTo.As4()
-
 		// Send a CreateRelayRequest to the peer.
 		// Send a CreateRelayRequest to the peer.
 		req := NebulaControl{
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         binary.BigEndian.Uint32(relayFromB[:]),
-			RelayToIp:           binary.BigEndian.Uint32(relayToB[:]),
 		}
 		}
+
+		switch newhostinfo.GetCert().Certificate.Version() {
+		case cert.Version1:
+			if !relayFrom.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				continue
+			}
+
+			if !relayTo.Is4() {
+				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				continue
+			}
+
+			b := relayFrom.As4()
+			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = relayTo.As4()
+			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		case cert.Version2:
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
+			req.RelayToAddr = netAddrToProtoAddr(relayTo)
+		default:
+			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			continue
+		}
+
 		msg, err := req.Marshal()
 		msg, err := req.Marshal()
 		if err != nil {
 		if err != nil {
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           req.RelayFromIp,
-				"relayTo":             req.RelayToIp,
+				"relayFrom":           req.RelayFromAddr,
+				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
-				"vpnIp":               newhostinfo.vpnIp}).
+				"vpnAddrs":            newhostinfo.vpnAddrs}).
 				Info("send CreateRelayRequest")
 				Info("send CreateRelayRequest")
 		}
 		}
 	}
 	}
@@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		return closeTunnel, hostinfo, nil
 		return closeTunnel, hostinfo, nil
 	}
 	}
 
 
-	primary := n.hostMap.Hosts[hostinfo.vpnIp]
+	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
 		mainHostInfo = false
@@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 	// Let's sort this out.
 
 
-	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
+	//TODO: current.vpnIp should become an array of vpnIps
+	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
 		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
 		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
 		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
 		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
 		// The remotes vpn ip is lower than mine. I will not flip.
 		// The remotes vpn ip is lower than mine. I will not flip.
 		return false
 		return false
 	}
 	}
 
 
-	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature())
+	//TODO: we should favor v2 over v1 certificates if configured to send them
+
+	crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version())
+	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 }
 
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 	n.hostMap.Lock()
 	n.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnIp] == primary {
+	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
 		n.hostMap.unlockedMakePrimary(current)
 		n.hostMap.unlockedMakePrimary(current)
 	}
 	}
 	n.hostMap.Unlock()
 	n.hostMap.Unlock()
@@ -473,14 +495,16 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 }
 
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) {
+	crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version())
+	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) {
 		return
 		return
 	}
 	}
 
 
-	n.l.WithField("vpnIp", hostinfo.vpnIp).
+	//TODO: we should favor v2 over v1 certificates if configured to send them
+
+	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 		Info("Re-handshaking with remote")
 
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
+	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }
 }

+ 26 - 29
connection_manager_test.go

@@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
@@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 
 
 	// Add an ip we have established a connection w/ to hostmap
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 		remoteIndexId: 9901,
 	}
 	}
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &dummyCert{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 		H:      &noise.HandshakeState{},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.Out(hostinfo.localIndexId)
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 	assert.Contains(t, nc.out, hostinfo.localIndexId)
 
 
@@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 
 	// Do a final traffic check tick, the host should now be removed
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 }
 
 
 func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
@@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 
 
 	// Add an ip we have established a connection w/ to hostmap
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		localIndexId:  1099,
 		localIndexId:  1099,
 		remoteIndexId: 9901,
 		remoteIndexId: 9901,
 	}
 	}
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
-		myCert: &dummyCert{},
+		myCert: &dummyCert{version: cert.Version1},
 		H:      &noise.HandshakeState{},
 		H:      &noise.HandshakeState{},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	// We saw traffic out to vpnIp
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.Out(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
@@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 
 	// We saw traffic, should no longer be pending deletion
 	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
 	nc.In(hostinfo.localIndexId)
@@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.out, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
 }
 
 
 // Check if we can disconnect the peer.
 // Check if we can disconnect the peer.
@@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	vpnIp := netip.MustParseAddr("172.1.1.2")
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
-	hostMap := newHostMap(l, vpncidr)
+	hostMap := newHostMap(l)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 
 	// Generate keys for CA and peer's cert.
 	// Generate keys for CA and peer's cert.
@@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 	cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
@@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.connectionManager = nc
 	ifce.connectionManager = nc
 
 
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp: vpnIp,
+		vpnAddrs: []netip.Addr{vpnIp},
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			myCert:   &dummyCert{},
 			myCert:   &dummyCert{},
 			peerCert: cachedPeerCert,
 			peerCert: cachedPeerCert,

+ 26 - 21
connection_state.go

@@ -3,6 +3,7 @@ package nebula
 import (
 import (
 	"crypto/rand"
 	"crypto/rand"
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 
 
@@ -26,46 +27,46 @@ type ConnectionState struct {
 	writeLock      sync.Mutex
 	writeLock      sync.Mutex
 }
 }
 
 
-func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
 	var dhFunc noise.DHFunc
 	var dhFunc noise.DHFunc
-	switch certState.Certificate.Curve() {
+	switch crt.Curve() {
 	case cert.Curve_CURVE25519:
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
 	case cert.Curve_P256:
-		if certState.pkcs11Backed {
+		if cs.pkcs11Backed {
 			dhFunc = noiseutil.DHP256PKCS11
 			dhFunc = noiseutil.DHP256PKCS11
 		} else {
 		} else {
 			dhFunc = noiseutil.DHP256
 			dhFunc = noiseutil.DHP256
 		}
 		}
 	default:
 	default:
-		l.Errorf("invalid curve: %s", certState.Certificate.Curve())
-		return nil
+		return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
 	}
 	}
 
 
-	var cs noise.CipherSuite
-	if cipher == "chachapoly" {
-		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	var ncs noise.CipherSuite
+	if cs.cipher == "chachapoly" {
+		ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	} else {
 	} else {
-		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
+		ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 	}
 
 
-	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
+	static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
 
 
 	b := NewBits(ReplayWindow)
 	b := NewBits(ReplayWindow)
-	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
+	// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
 	b.Update(l, 0)
 	b.Update(l, 0)
 
 
 	hs, err := noise.NewHandshakeState(noise.Config{
 	hs, err := noise.NewHandshakeState(noise.Config{
-		CipherSuite:           cs,
-		Random:                rand.Reader,
-		Pattern:               pattern,
-		Initiator:             initiator,
-		StaticKeypair:         static,
-		PresharedKey:          psk,
-		PresharedKeyPlacement: pskStage,
+		CipherSuite:   ncs,
+		Random:        rand.Reader,
+		Pattern:       pattern,
+		Initiator:     initiator,
+		StaticKeypair: static,
+		//NOTE: These should come from CertState (pki.go) when we finally implement it
+		PresharedKey:          []byte{},
+		PresharedKeyPlacement: 0,
 	})
 	})
 	if err != nil {
 	if err != nil {
-		return nil
+		return nil, fmt.Errorf("NewConnectionState: %s", err)
 	}
 	}
 
 
 	// The queue and ready params prevent a counter race that would happen when
 	// The queue and ready params prevent a counter race that would happen when
@@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		H:         hs,
 		initiator: initiator,
 		initiator: initiator,
 		window:    b,
 		window:    b,
-		myCert:    certState.Certificate,
+		myCert:    crt,
 	}
 	}
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
 	ci.messageCounter.Add(2)
 	ci.messageCounter.Add(2)
 
 
-	return ci
+	return ci, nil
 }
 }
 
 
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
@@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"message_counter": cs.messageCounter.Load(),
 		"message_counter": cs.messageCounter.Load(),
 	})
 	})
 }
 }
+
+func (cs *ConnectionState) Curve() cert.Curve {
+	return cs.myCert.Curve()
+}

+ 25 - 19
control.go

@@ -19,9 +19,9 @@ import (
 type controlEach func(h *HostInfo)
 type controlEach func(h *HostInfo)
 
 
 type controlHostLister interface {
 type controlHostLister interface {
-	QueryVpnIp(vpnIp netip.Addr) *HostInfo
+	QueryVpnAddr(vpnAddr netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
 	ForEachIndex(each controlEach)
-	ForEachVpnIp(each controlEach)
+	ForEachVpnAddr(each controlEach)
 	GetPreferredRanges() []netip.Prefix
 	GetPreferredRanges() []netip.Prefix
 }
 }
 
 
@@ -37,7 +37,7 @@ type Control struct {
 }
 }
 
 
 type ControlHostInfo struct {
 type ControlHostInfo struct {
-	VpnIp                  netip.Addr       `json:"vpnIp"`
+	VpnAddrs               []netip.Addr     `json:"vpnAddrs"`
 	LocalIndex             uint32           `json:"localIndex"`
 	LocalIndex             uint32           `json:"localIndex"`
 	RemoteIndex            uint32           `json:"remoteIndex"`
 	RemoteIndex            uint32           `json:"remoteIndex"`
 	RemoteAddrs            []netip.AddrPort `json:"remoteAddrs"`
 	RemoteAddrs            []netip.AddrPort `json:"remoteAddrs"`
@@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 
 
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
 // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found
 func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
 func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
-	if c.f.myVpnNet.Addr() == vpnIp {
-		return c.f.pki.GetCertState().Certificate.Copy()
+	_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
+	if found {
+		//TODO: we might have 2 certs....
+		//TODO: this should return our latest version cert
+		return c.f.pki.getDefaultCertificate().Copy()
 	}
 	}
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 	if hi == nil {
 		return nil
 		return nil
 	}
 	}
@@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) {
 
 
 // PrintTunnel creates a new tunnel to the given vpn ip.
 // PrintTunnel creates a new tunnel to the given vpn ip.
 func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
 func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo {
-	hi := c.f.hostMap.QueryVpnIp(vpnIp)
+	hi := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hi == nil {
 	if hi == nil {
 		return nil
 		return nil
 	}
 	}
@@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap {
 	return hi.CopyCache()
 	return hi.CopyCache()
 }
 }
 
 
-// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
+// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
-func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
+func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	var hl controlHostLister
 	if pending {
 	if pending {
 		hl = c.f.handshakeManager
 		hl = c.f.handshakeManager
@@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 		hl = c.f.hostMap
 		hl = c.f.hostMap
 	}
 	}
 
 
-	h := hl.QueryVpnIp(vpnIp)
+	h := hl.QueryVpnAddr(vpnAddr)
 	if h == nil {
 	if h == nil {
 		return nil
 		return nil
 	}
 	}
@@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
 func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return nil
 		return nil
 	}
 	}
@@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 // Caller should take care to Unmap() any 4in6 addresses prior to calling.
 func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
-	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return false
 		return false
 	}
 	}
@@ -229,14 +232,14 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 
 
 	shutdown := func(h *HostInfo) {
 	shutdown := func(h *HostInfo) {
 		if excludeLighthouses {
 		if excludeLighthouses {
-			if _, ok := lighthouses[h.vpnIp]; ok {
+			if _, ok := lighthouses[h.vpnAddrs[0]]; ok {
 				return
 				return
 			}
 			}
 		}
 		}
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 		c.f.closeTunnel(h)
 		c.f.closeTunnel(h)
 
 
-		c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote).
+		c.l.WithField("vpnIp", h.vpnAddrs[0]).WithField("udpAddr", h.remote).
 			Debug("Sending close tunnel message")
 			Debug("Sending close tunnel message")
 		closed++
 		closed++
 	}
 	}
@@ -246,7 +249,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Relays map
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
 	for _, relayingHost := range c.f.hostMap.Relays {
-		relayingHosts[relayingHost.vpnIp] = relayingHost
+		relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost
 	}
 	}
 	c.f.hostMap.Unlock()
 	c.f.hostMap.Unlock()
 
 
@@ -254,7 +257,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	// Grab the hostMap lock to access the Hosts map
 	// Grab the hostMap lock to access the Hosts map
 	c.f.hostMap.Lock()
 	c.f.hostMap.Lock()
 	for _, relayHost := range c.f.hostMap.Indexes {
 	for _, relayHost := range c.f.hostMap.Indexes {
-		if _, ok := relayingHosts[relayHost.vpnIp]; !ok {
+		if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok {
 			hostInfos = append(hostInfos, relayHost)
 			hostInfos = append(hostInfos, relayHost)
 		}
 		}
 	}
 	}
@@ -274,9 +277,8 @@ func (c *Control) Device() overlay.Device {
 }
 }
 
 
 func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
-
 	chi := ControlHostInfo{
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp,
+		VpnAddrs:               make([]netip.Addr, len(h.vpnAddrs)),
 		LocalIndex:             h.localIndexId,
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
@@ -285,6 +287,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 		CurrentRemote:          h.remote,
 		CurrentRemote:          h.remote,
 	}
 	}
 
 
+	for i, a := range h.vpnAddrs {
+		chi.VpnAddrs[i] = a
+	}
+
 	if h.ConnectionState != nil {
 	if h.ConnectionState != nil {
 		chi.MessageCounter = h.ConnectionState.messageCounter.Load()
 		chi.MessageCounter = h.ConnectionState.messageCounter.Load()
 	}
 	}
@@ -299,7 +305,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
 	hosts := make([]ControlHostInfo, 0)
 	hosts := make([]ControlHostInfo, 0)
 	pr := hl.GetPreferredRanges()
 	pr := hl.GetPreferredRanges()
-	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+	hl.ForEachVpnAddr(func(hostinfo *HostInfo) {
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 		hosts = append(hosts, copyHostInfo(hostinfo, pr))
 	})
 	})
 	return hosts
 	return hosts

+ 16 - 16
control_test.go

@@ -19,7 +19,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, netip.Prefix{})
+	hm := newHostMap(l)
 	hm.preferredRanges.Store(&[]netip.Prefix{})
 	hm.preferredRanges.Store(&[]netip.Prefix{})
 
 
 	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
 	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
@@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 	}
 
 
-	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
-	remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+	remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil)
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port()))
 
 
 	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
 	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
 	assert.True(t, ok)
 	assert.True(t, ok)
@@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		vpnIp:         vpnIp,
+		vpnAddrs:      []netip.Addr{vpnIp},
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}, &Interface{})
 	}, &Interface{})
 
 
@@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		vpnIp:         vpnIp2,
+		vpnAddrs:      []netip.Addr{vpnIp2},
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}, &Interface{})
 	}, &Interface{})
 
 
@@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 		l: logrus.New(),
 	}
 	}
 
 
-	thi := c.GetHostInfoByVpnIp(vpnIp, false)
+	thi := c.GetHostInfoByVpnAddr(vpnIp, false)
 
 
 	expectedInfo := ControlHostInfo{
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  vpnIp,
+		VpnAddrs:               []netip.Addr{vpnIp},
 		LocalIndex:             201,
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteIndex:            200,
 		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
 		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
@@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 	}
 
 
 	// Make sure we don't have any unexpected fields
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	assert.EqualValues(t, &expectedInfo, thi)
 	assert.EqualValues(t, &expectedInfo, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(vpnIp2, false)
+		thi = c.GetHostInfoByVpnAddr(vpnIp2, false)
 	})
 	})
 }
 }
 
 

+ 42 - 21
control_tester.go

@@ -6,8 +6,6 @@ package nebula
 import (
 import (
 	"net/netip"
 	"net/netip"
 
 
-	"github.com/slackhq/nebula/cert"
-
 	"github.com/google/gopacket"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
@@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
 
 
 	if toAddr.Addr().Is4() {
 	if toAddr.Addr().Is4() {
-		remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port()))
 	} else {
 	} else {
-		remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
+		remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port()))
 	}
 	}
 }
 }
 
 
@@ -67,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort)
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
+	remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp})
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
@@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 }
 
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
-	//TODO: IPV6-WORK
-	ip := layers.IPv4{
-		Version:  4,
-		TTL:      64,
-		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().Addr().Unmap().AsSlice(),
-		DstIP:    toIp.Unmap().AsSlice(),
+func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) {
+	serialize := make([]gopacket.SerializableLayer, 0)
+	var netLayer gopacket.NetworkLayer
+	if toAddr.Is6() {
+		if !fromAddr.Is6() {
+			panic("Cant send ipv6 to ipv4")
+		}
+		ip := &layers.IPv6{
+			Version:    6,
+			NextHeader: layers.IPProtocolUDP,
+			SrcIP:      fromAddr.Unmap().AsSlice(),
+			DstIP:      toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
+	} else {
+		if !fromAddr.Is4() {
+			panic("Cant send ipv4 to ipv6")
+		}
+
+		ip := &layers.IPv4{
+			Version:  4,
+			TTL:      64,
+			Protocol: layers.IPProtocolUDP,
+			SrcIP:    fromAddr.Unmap().AsSlice(),
+			DstIP:    toAddr.Unmap().AsSlice(),
+		}
+		serialize = append(serialize, ip)
+		netLayer = ip
 	}
 	}
 
 
 	udp := layers.UDP{
 	udp := layers.UDP{
 		SrcPort: layers.UDPPort(fromPort),
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 		DstPort: layers.UDPPort(toPort),
 	}
 	}
-	err := udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(netLayer)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 		ComputeChecksums: true,
 		ComputeChecksums: true,
 		FixLengths:       true,
 		FixLengths:       true,
 	}
 	}
-	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+
+	serialize = append(serialize, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, serialize...)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 }
 
 
-func (c *Control) GetVpnIp() netip.Addr {
-	return c.f.myVpnNet.Addr()
+func (c *Control) GetVpnAddrs() []netip.Addr {
+	return c.f.myVpnAddrs
 }
 }
 
 
 func (c *Control) GetUDPAddr() netip.AddrPort {
 func (c *Control) GetUDPAddr() netip.AddrPort {
@@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort {
 }
 }
 
 
 func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
 func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
+	hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		return false
 		return false
 	}
 	}
@@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap {
 	return c.f.hostMap
 	return c.f.hostMap
 }
 }
 
 
-func (c *Control) GetCert() cert.Certificate {
-	return c.f.pki.GetCertState().Certificate
+func (c *Control) GetCertState() *CertState {
+	return c.f.pki.getCertState()
 }
 }
 
 
 func (c *Control) ReHandshake(vpnIp netip.Addr) {
 func (c *Control) ReHandshake(vpnIp netip.Addr) {

+ 74 - 36
dns_server.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -21,24 +22,39 @@ var dnsAddr string
 
 
 type dnsRecords struct {
 type dnsRecords struct {
 	sync.RWMutex
 	sync.RWMutex
-	dnsMap  map[string]string
-	hostMap *HostMap
+	l               *logrus.Logger
+	dnsMap4         map[string]netip.Addr
+	dnsMap6         map[string]netip.Addr
+	hostMap         *HostMap
+	myVpnAddrsTable *bart.Table[struct{}]
 }
 }
 
 
-func newDnsRecords(hostMap *HostMap) *dnsRecords {
+func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
 	return &dnsRecords{
 	return &dnsRecords{
-		dnsMap:  make(map[string]string),
-		hostMap: hostMap,
+		l:               l,
+		dnsMap4:         make(map[string]netip.Addr),
+		dnsMap6:         make(map[string]netip.Addr),
+		hostMap:         hostMap,
+		myVpnAddrsTable: cs.myVpnAddrsTable,
 	}
 	}
 }
 }
 
 
-func (d *dnsRecords) Query(data string) string {
+func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
+	data = strings.ToLower(data)
 	d.RLock()
 	d.RLock()
 	defer d.RUnlock()
 	defer d.RUnlock()
-	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
-		return r
+	switch q {
+	case dns.TypeA:
+		if r, ok := d.dnsMap4[data]; ok {
+			return r
+		}
+	case dns.TypeAAAA:
+		if r, ok := d.dnsMap6[data]; ok {
+			return r
+		}
 	}
 	}
-	return ""
+
+	return netip.Addr{}
 }
 }
 
 
 func (d *dnsRecords) QueryCert(data string) string {
 func (d *dnsRecords) QueryCert(data string) string {
@@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 		return ""
 	}
 	}
 
 
-	hostinfo := d.hostMap.QueryVpnIp(ip)
+	hostinfo := d.hostMap.QueryVpnAddr(ip)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		return ""
 		return ""
 	}
 	}
@@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string {
 	return string(b)
 	return string(b)
 }
 }
 
 
-func (d *dnsRecords) Add(host, data string) {
+// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
+func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
+	host = strings.ToLower(host)
 	d.Lock()
 	d.Lock()
 	defer d.Unlock()
 	defer d.Unlock()
-	d.dnsMap[strings.ToLower(host)] = data
+	haveV4 := false
+	haveV6 := false
+	for _, addr := range addresses {
+		if addr.Is4() && !haveV4 {
+			d.dnsMap4[host] = addr
+			haveV4 = true
+		} else if addr.Is6() && !haveV6 {
+			d.dnsMap6[host] = addr
+			haveV6 = true
+		}
+		if haveV4 && haveV6 {
+			break
+		}
+	}
+}
+
+func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
+	a, _, _ := net.SplitHostPort(addr)
+	b, err := netip.ParseAddr(a)
+	if err != nil {
+		return false
+	}
+
+	if b.IsLoopback() {
+		return true
+	}
+
+	_, found := d.myVpnAddrsTable.Lookup(b)
+	return found //if we found it in this table, it's good
 }
 }
 
 
-func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
+func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
 	for _, q := range m.Question {
 	for _, q := range m.Question {
 		switch q.Qtype {
 		switch q.Qtype {
-		case dns.TypeA:
-			l.Debugf("Query for A %s", q.Name)
-			ip := dnsR.Query(q.Name)
-			if ip != "" {
-				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
+		case dns.TypeA, dns.TypeAAAA:
+			qType := dns.TypeToString[q.Qtype]
+			d.l.Debugf("Query for %s %s", qType, q.Name)
+			ip := d.Query(q.Qtype, q.Name)
+			if ip.IsValid() {
+				rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
 				if err == nil {
 				if err == nil {
 					m.Answer = append(m.Answer, rr)
 					m.Answer = append(m.Answer, rr)
 				}
 				}
 			}
 			}
 		case dns.TypeTXT:
 		case dns.TypeTXT:
-			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b, err := netip.ParseAddr(a)
-			if err != nil {
+			// We only answer these queries from nebula nodes or localhost
+			if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) {
 				return
 				return
 			}
 			}
-
-			// We don't answer these queries from non nebula nodes or localhost
-			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
-			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
-				return
-			}
-			l.Debugf("Query for TXT %s", q.Name)
-			ip := dnsR.QueryCert(q.Name)
+			d.l.Debugf("Query for TXT %s", q.Name)
+			ip := d.QueryCert(q.Name)
 			if ip != "" {
 			if ip != "" {
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
 				if err == nil {
 				if err == nil {
@@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 	}
 	}
 }
 }
 
 
-func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
+func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
 	m := new(dns.Msg)
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.SetReply(r)
 	m.Compress = false
 	m.Compress = false
 
 
 	switch r.Opcode {
 	switch r.Opcode {
 	case dns.OpcodeQuery:
 	case dns.OpcodeQuery:
-		parseQuery(l, m, w)
+		d.parseQuery(m, w)
 	}
 	}
 
 
 	w.WriteMsg(m)
 	w.WriteMsg(m)
 }
 }
 
 
-func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
-	dnsR = newDnsRecords(hostMap)
+func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
+	dnsR = newDnsRecords(l, cs, hostMap)
 
 
 	// attach request handler func
 	// attach request handler func
-	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
-		handleDnsRequest(l, w, r)
-	})
+	dns.HandleFunc(".", dnsR.handleDnsRequest)
 
 
 	c.RegisterReloadCallback(func(c *config.C) {
 	c.RegisterReloadCallback(func(c *config.C) {
 		reloadDns(l, c)
 		reloadDns(l, c)

+ 20 - 5
dns_server_test.go

@@ -1,23 +1,38 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"net/netip"
 	"testing"
 	"testing"
 
 
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestParsequery(t *testing.T) {
 func TestParsequery(t *testing.T) {
-	//TODO: This test is basically pointless
+	l := logrus.New()
 	hostMap := &HostMap{}
 	hostMap := &HostMap{}
-	ds := newDnsRecords(hostMap)
-	ds.Add("test.com.com", "1.2.3.4")
+	ds := newDnsRecords(l, &CertState{}, hostMap)
+	addrs := []netip.Addr{
+		netip.MustParseAddr("1.2.3.4"),
+		netip.MustParseAddr("1.2.3.5"),
+		netip.MustParseAddr("fd01::24"),
+		netip.MustParseAddr("fd01::25"),
+	}
+	ds.Add("test.com.com", addrs)
 
 
-	m := new(dns.Msg)
+	m := &dns.Msg{}
 	m.SetQuestion("test.com.com", dns.TypeA)
 	m.SetQuestion("test.com.com", dns.TypeA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String())
 
 
-	//parseQuery(m)
+	m = &dns.Msg{}
+	m.SetQuestion("test.com.com", dns.TypeAAAA)
+	ds.parseQuery(m, nil)
+	assert.NotNil(t, m.Answer)
+	assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String())
 }
 }
 
 
 func Test_getDnsServerAddr(t *testing.T) {
 func Test_getDnsServerAddr(t *testing.T) {

+ 208 - 171
e2e/handshakes_test.go

@@ -4,7 +4,6 @@
 package e2e
 package e2e
 
 
 import (
 import (
-	"fmt"
 	"net/netip"
 	"net/netip"
 	"slices"
 	"slices"
 	"testing"
 	"testing"
@@ -12,6 +11,7 @@ import (
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
@@ -21,11 +21,11 @@ import (
 
 
 func BenchmarkHotPath(b *testing.B) {
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
@@ -35,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) {
 	r.CancelFlowLogs()
 	r.CancelFlowLogs()
 
 
 	for n := 0; n < b.N; n++ {
 	for n := 0; n < b.N; n++ {
-		myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
 		_ = r.RouteForAllUntilTxTun(theirControl)
 	}
 	}
 
 
@@ -45,18 +45,18 @@ func BenchmarkHotPath(b *testing.B) {
 
 
 func TestGoodHandshake(t *testing.T) {
 func TestGoodHandshake(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -77,16 +77,16 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.WaitForType(1, 0, theirControl)
 	myControl.WaitForType(1, 0, theirControl)
 
 
 	t.Log("Make sure our host infos are correct")
 	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
 
 
 	t.Log("Get that cached packet and make sure it looks right")
 	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 
 
 	t.Log("Do a bidirectional tunnel test")
 	t.Log("Do a bidirectional tunnel test")
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
 	defer r.RenderFlow()
 	defer r.RenderFlow()
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
 	myControl.Stop()
@@ -97,12 +97,12 @@ func TestGoodHandshake(t *testing.T) {
 func TestWrongResponderHandshake(t *testing.T) {
 func TestWrongResponderHandshake(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 
 
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
-	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
+	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
 
 
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl, evilControl)
 	r := router.NewR(t, myControl, theirControl, evilControl)
@@ -114,7 +114,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 	evilControl.Start()
 	evilControl.Start()
 
 
 	t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
 	t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	h := &header.H{}
 	h := &header.H{}
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -131,8 +131,8 @@ func TestWrongResponderHandshake(t *testing.T) {
 	})
 	})
 
 
 	t.Log("Evil tunnel is closed, inject the correct udp addr for them")
 	t.Log("Evil tunnel is closed, inject the correct udp addr for them")
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
 	assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
 	assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
 
 
 	t.Log("Route until we see the cached packet")
 	t.Log("Route until we see the cached packet")
@@ -153,18 +153,18 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 
 	t.Log("My cached packet should be received by them")
 	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 
 
 	t.Log("Test the tunnel with them")
 	t.Log("Test the tunnel with them")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	t.Log("Flush all packets from all controllers")
 	t.Log("Flush all packets from all controllers")
 	r.FlushAll()
 	r.FlushAll()
 
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
 
 
 	//TODO: assert hostmaps for everyone
 	//TODO: assert hostmaps for everyone
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
@@ -176,17 +176,17 @@ func TestWrongResponderHandshake(t *testing.T) {
 func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
 func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 
 
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
-	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil)
+	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil)
 	o := m{
 	o := m{
 		"static_host_map": m{
 		"static_host_map": m{
-			theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()},
+			theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()},
 		},
 		},
 	}
 	}
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o)
 
 
 	// Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr.
 	// Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr.
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl, evilControl)
 	r := router.NewR(t, myControl, theirControl, evilControl)
@@ -198,7 +198,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
 	evilControl.Start()
 	evilControl.Start()
 
 
 	t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
 	t.Log("Start the handshake process, we will route until we see the evil tunnel closed")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	h := &header.H{}
 	h := &header.H{}
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -215,8 +215,8 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
 	})
 	})
 
 
 	t.Log("Evil tunnel is closed, inject the correct udp addr for them")
 	t.Log("Evil tunnel is closed, inject the correct udp addr for them")
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true)
 	assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
 	assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr)
 
 
 	t.Log("Route until we see the cached packet")
 	t.Log("Route until we see the cached packet")
@@ -237,18 +237,19 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) {
 
 
 	t.Log("My cached packet should be received by them")
 	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 
 
 	t.Log("Test the tunnel with them")
 	t.Log("Test the tunnel with them")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	t.Log("Flush all packets from all controllers")
 	t.Log("Flush all packets from all controllers")
 	r.FlushAll()
 	r.FlushAll()
 
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil")
+	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 
 	//TODO: assert hostmaps for everyone
 	//TODO: assert hostmaps for everyone
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl)
@@ -262,12 +263,12 @@ func TestStage1Race(t *testing.T) {
 	// But will eventually collapse down to a single tunnel
 	// But will eventually collapse down to a single tunnel
 
 
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -278,8 +279,8 @@ func TestStage1Race(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake to start on both me and them")
 	t.Log("Trigger a handshake to start on both me and them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
 
 
 	t.Log("Get both stage 1 handshake packets")
 	t.Log("Get both stage 1 handshake packets")
 	myHsForThem := myControl.GetFromUDP(true)
 	myHsForThem := myControl.GetFromUDP(true)
@@ -291,14 +292,14 @@ func TestStage1Race(t *testing.T) {
 
 
 	r.Log("Route until they receive a message packet")
 	r.Log("Route until they receive a message packet")
 	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
 	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 
 
 	r.Log("Their cached packet should be received by me")
 	r.Log("Their cached packet should be received by me")
 	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
 	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
-	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
 
 
 	r.Log("Do a bidirectional tunnel test")
 	r.Log("Do a bidirectional tunnel test")
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	myHostmapHosts := myControl.ListHostmapHosts(false)
 	myHostmapHosts := myControl.ListHostmapHosts(false)
 	myHostmapIndexes := myControl.ListHostmapIndexes(false)
 	myHostmapIndexes := myControl.ListHostmapIndexes(false)
@@ -316,7 +317,7 @@ func TestStage1Race(t *testing.T) {
 	r.Log("Spin until connection manager tears down a tunnel")
 	r.Log("Spin until connection manager tears down a tunnel")
 
 
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -339,12 +340,12 @@ func TestStage1Race(t *testing.T) {
 
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
 func TestUncleanShutdownRaceLoser(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -355,10 +356,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	r.Log("Trigger a handshake from me to them")
 	r.Log("Trigger a handshake from me to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 
 
 	r.Log("Nuke my hostmap")
 	r.Log("Nuke my hostmap")
 	myHostmap := myControl.GetHostmap()
 	myHostmap := myControl.GetHostmap()
@@ -366,17 +367,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 
 
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))
 	p = r.RouteForAllUntilTxTun(theirControl)
 	p = r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	r.Log("Wait for the dead index to go away")
 	r.Log("Wait for the dead index to go away")
 	start := len(theirControl.GetHostmap().Indexes)
 	start := len(theirControl.GetHostmap().Indexes)
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		if len(theirControl.GetHostmap().Indexes) < start {
 		if len(theirControl.GetHostmap().Indexes) < start {
 			break
 			break
 		}
 		}
@@ -388,12 +389,12 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
 func TestUncleanShutdownRaceWinner(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -404,10 +405,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	r.Log("Trigger a handshake from me to them")
 	r.Log("Trigger a handshake from me to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 
 
 	r.Log("Nuke my hostmap")
 	r.Log("Nuke my hostmap")
@@ -416,18 +417,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 
 
-	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))
 	p = r.RouteForAllUntilTxTun(myControl)
 	p = r.RouteForAllUntilTxTun(myControl)
-	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80)
 	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	r.Log("Wait for the dead index to go away")
 	r.Log("Wait for the dead index to go away")
 	start := len(myControl.GetHostmap().Indexes)
 	start := len(myControl.GetHostmap().Indexes)
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		if len(myControl.GetHostmap().Indexes) < start {
 		if len(myControl.GetHostmap().Indexes) < start {
 			break
 			break
 		}
 		}
@@ -439,14 +440,14 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 
 
 func TestRelays(t *testing.T) {
 func TestRelays(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -458,11 +459,11 @@ func TestRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 }
 }
@@ -470,19 +471,19 @@ func TestRelays(t *testing.T) {
 func TestStage1RaceRelays(t *testing.T) {
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
-	theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
 
 
-	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
-	theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
 
 
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -494,14 +495,14 @@ func TestStage1RaceRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	r.Log("Get a tunnel between me and relay")
 	r.Log("Get a tunnel between me and relay")
-	assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
 
 
 	r.Log("Get a tunnel between them and relay")
 	r.Log("Get a tunnel between them and relay")
-	assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+	assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
 
 
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
 
 
 	r.Log("Wait for a packet from them to me")
 	r.Log("Wait for a packet from them to me")
 	p := r.RouteForAllUntilTxTun(myControl)
 	p := r.RouteForAllUntilTxTun(myControl)
@@ -519,20 +520,20 @@ func TestStage1RaceRelays(t *testing.T) {
 func TestStage1RaceRelays2(t *testing.T) {
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 	l := NewTestLogger()
 	l := NewTestLogger()
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
-	theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
 
 
-	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
-	theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
 
 
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -545,16 +546,16 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 	r.Log("Get a tunnel between me and relay")
 	r.Log("Get a tunnel between me and relay")
 	l.Info("Get a tunnel between me and relay")
 	l.Info("Get a tunnel between me and relay")
-	assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
 
 
 	r.Log("Get a tunnel between them and relay")
 	r.Log("Get a tunnel between them and relay")
 	l.Info("Get a tunnel between them and relay")
 	l.Info("Get a tunnel between them and relay")
-	assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+	assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
 
 
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	l.Info("Trigger a handshake from both them and me via relay to them and me")
 	l.Info("Trigger a handshake from both them and me via relay to them and me")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
 
 
 	//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -567,7 +568,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
-	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 
 
 	t.Log("Wait until we remove extra tunnels")
 	t.Log("Wait until we remove extra tunnels")
 	l.Info("Wait until we remove extra tunnels")
 	l.Info("Wait until we remove extra tunnels")
@@ -587,7 +588,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 				"theirControl": len(theirControl.GetHostmap().Indexes),
 				"theirControl": len(theirControl.GetHostmap().Indexes),
 				"relayControl": len(relayControl.GetHostmap().Indexes),
 				"relayControl": len(relayControl.GetHostmap().Indexes),
 			}).Info("Waiting for hostinfos to be removed...")
 			}).Info("Waiting for hostinfos to be removed...")
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 		retries--
 		retries--
@@ -595,7 +596,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
-	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 
 
 	myControl.Stop()
 	myControl.Stop()
 	theirControl.Stop()
 	theirControl.Stop()
@@ -607,14 +608,14 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 func TestRehandshakingRelays(t *testing.T) {
 func TestRehandshakingRelays(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -626,17 +627,17 @@ func TestRehandshakingRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 
 
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 
 	caB, err := ca.MarshalPEM()
 	caB, err := ca.MarshalPEM()
 	if err != nil {
 	if err != nil {
@@ -654,8 +655,8 @@ func TestRehandshakingRelays(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
-		c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
+		assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
 		if len(c.Cert.Groups()) != 0 {
 		if len(c.Cert.Groups()) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between my and relay is updated!")
 			r.Log("Certificate between my and relay is updated!")
@@ -667,8 +668,8 @@ func TestRehandshakingRelays(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
-		c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
 		if len(c.Cert.Groups()) != 0 {
 		if len(c.Cert.Groups()) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between their and relay is updated!")
 			r.Log("Certificate between their and relay is updated!")
@@ -679,13 +680,13 @@ func TestRehandshakingRelays(t *testing.T) {
 	}
 	}
 
 
 	r.Log("Assert the relay tunnel still works")
 	r.Log("Assert the relay tunnel still works")
-	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	// We should have two hostinfos on all sides
 	// We should have two hostinfos on all sides
 	for len(myControl.GetHostmap().Indexes) != 2 {
 	for len(myControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -693,7 +694,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -701,7 +702,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -711,14 +712,14 @@ func TestRehandshakingRelays(t *testing.T) {
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -730,17 +731,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 
 
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
 	r.Log("Renew relay certificate and spin until me and them sees it")
 	r.Log("Renew relay certificate and spin until me and them sees it")
-	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
 
 
 	caB, err := ca.MarshalPEM()
 	caB, err := ca.MarshalPEM()
 	if err != nil {
 	if err != nil {
@@ -758,8 +759,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
-		c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
+		assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
 		if len(c.Cert.Groups()) != 0 {
 		if len(c.Cert.Groups()) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between my and relay is updated!")
 			r.Log("Certificate between my and relay is updated!")
@@ -771,8 +772,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
-		c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false)
 		if len(c.Cert.Groups()) != 0 {
 		if len(c.Cert.Groups()) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between their and relay is updated!")
 			r.Log("Certificate between their and relay is updated!")
@@ -783,13 +784,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	}
 	}
 
 
 	r.Log("Assert the relay tunnel still works")
 	r.Log("Assert the relay tunnel still works")
-	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	// We should have two hostinfos on all sides
 	// We should have two hostinfos on all sides
 	for len(myControl.GetHostmap().Indexes) != 2 {
 	for len(myControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -797,7 +798,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -805,7 +806,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -814,12 +815,12 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 
 func TestRehandshaking(t *testing.T) {
 func TestRehandshaking(t *testing.T) {
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", "10.128.0.2/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me  ", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -830,12 +831,12 @@ func TestRehandshaking(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Stand up a tunnel between me and them")
 	t.Log("Stand up a tunnel between me and them")
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 
 	r.Log("Renew my certificate and spin until their sees it")
 	r.Log("Renew my certificate and spin until their sees it")
-	_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"})
+	_, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
 
 
 	caB, err := ca.MarshalPEM()
 	caB, err := ca.MarshalPEM()
 	if err != nil {
 	if err != nil {
@@ -852,8 +853,8 @@ func TestRehandshaking(t *testing.T) {
 	myConfig.ReloadConfigString(string(rc))
 	myConfig.ReloadConfigString(string(rc))
 
 
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
-		c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
 		if len(c.Cert.Groups()) != 0 {
 		if len(c.Cert.Groups()) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			break
 			break
@@ -880,19 +881,19 @@ func TestRehandshaking(t *testing.T) {
 
 
 	r.Log("Spin until there is only 1 tunnel")
 	r.Log("Spin until there is only 1 tunnel")
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
 
 
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 
 	// Make sure the correct tunnel won
 	// Make sure the correct tunnel won
-	c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
+	c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
 	assert.Contains(t, c.Cert.Groups(), "new group")
 	assert.Contains(t, c.Cert.Groups(), "new group")
 
 
 	// We should only have a single tunnel now on both sides
 	// We should only have a single tunnel now on both sides
@@ -911,12 +912,12 @@ func TestRehandshakingLoser(t *testing.T) {
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// Should be the one with the new certificate
 	// Should be the one with the new certificate
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", "10.128.0.2/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me  ", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -927,16 +928,12 @@ func TestRehandshakingLoser(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Stand up a tunnel between me and them")
 	t.Log("Stand up a tunnel between me and them")
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
-
-	tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
-	tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
-	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
 
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 
 	r.Log("Renew their certificate and spin until mine sees it")
 	r.Log("Renew their certificate and spin until mine sees it")
-	_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"})
+	_, _, theirNextPrivKey, theirNextPEM := NewTestCert(cert.Version1, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
 
 
 	caB, err := ca.MarshalPEM()
 	caB, err := ca.MarshalPEM()
 	if err != nil {
 	if err != nil {
@@ -953,8 +950,8 @@ func TestRehandshakingLoser(t *testing.T) {
 	theirConfig.ReloadConfigString(string(rc))
 	theirConfig.ReloadConfigString(string(rc))
 
 
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
-		theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
 
 
 		if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
 		if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
 			break
 			break
@@ -980,19 +977,19 @@ func TestRehandshakingLoser(t *testing.T) {
 
 
 	r.Log("Spin until there is only 1 tunnel")
 	r.Log("Spin until there is only 1 tunnel")
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
 
 
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 
 	// Make sure the correct tunnel won
 	// Make sure the correct tunnel won
-	theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+	theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
 	assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
 	assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
 
 
 	// We should only have a single tunnel now on both sides
 	// We should only have a single tunnel now on both sides
@@ -1011,12 +1008,12 @@ func TestRaceRegression(t *testing.T) {
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// caused a cross-linked hostinfo
 	// caused a cross-linked hostinfo
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
@@ -1030,8 +1027,8 @@ func TestRaceRegression(t *testing.T) {
 	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
 	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
 
 
 	t.Log("Start both handshakes")
 	t.Log("Start both handshakes")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))
 
 
 	t.Log("Get both stage 1")
 	t.Log("Get both stage 1")
 	myStage1ForThem := myControl.GetFromUDP(true)
 	myStage1ForThem := myControl.GetFromUDP(true)
@@ -1061,8 +1058,48 @@ func TestRaceRegression(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 
 	t.Log("Make sure the tunnel still works")
 	t.Log("Make sure the tunnel still works")
-	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestV2NonPrimaryWithLighthouse(t *testing.T) {
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh  ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}})
+
+	o := m{
+		"static_host_map": m{
+			lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()},
+		},
+		"lighthouse": m{
+			"hosts": []string{lhVpnIpNet[1].Addr().String()},
+			"local_allow_list": m{
+				// Try and block our lighthouse updates from using the actual addresses assigned to this computer
+				// If we start discovering addresses the test router doesn't know about then test traffic cant flow
+				"10.0.0.0/24": true,
+				"::/0":        false,
+			},
+		},
+	}
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me  ", "10.128.0.2/24, ff::2/64", o)
+	theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, lhControl, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	lhControl.Start()
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up an ipv6 tunnel between me and them")
+	assert.True(t, myVpnIpNet[1].Addr().Is6())
+	assert.True(t, theirVpnIpNet[1].Addr().Is6())
+	assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r)
 
 
+	lhControl.Stop()
 	myControl.Stop()
 	myControl.Stop()
 	theirControl.Stop()
 	theirControl.Stop()
 }
 }

+ 2 - 2
e2e/helpers.go

@@ -48,7 +48,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre
 
 
 // NewTestCert will generate a signed certificate with the provided details.
 // NewTestCert will generate a signed certificate with the provided details.
 // Expiry times are defaulted if you do not pass them in
 // Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
+func NewTestCert(v cert.Version, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
 	if before.IsZero() {
 	if before.IsZero() {
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
 	}
 	}
@@ -59,7 +59,7 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim
 
 
 	pub, rawPriv := x25519Keypair()
 	pub, rawPriv := x25519Keypair()
 	nc := &cert.TBSCertificate{
 	nc := &cert.TBSCertificate{
-		Version:        cert.Version1,
+		Version:        v,
 		Name:           name,
 		Name:           name,
 		Networks:       networks,
 		Networks:       networks,
 		UnsafeNetworks: unsafeNetworks,
 		UnsafeNetworks: unsafeNetworks,

+ 73 - 21
e2e/helpers_test.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"io"
 	"net/netip"
 	"net/netip"
 	"os"
 	"os"
+	"strings"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -26,25 +27,35 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 // newSimpleServer creates a nebula instance with many assumptions
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
+func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 	l := NewTestLogger()
 
 
-	vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
-	if err != nil {
-		panic(err)
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
 	}
 	}
 
 
 	var udpAddr netip.AddrPort
 	var udpAddr netip.AddrPort
-	if vpnIpNet.Addr().Is4() {
-		budpIp := vpnIpNet.Addr().As4()
+	if vpnNetworks[0].Addr().Is4() {
+		budpIp := vpnNetworks[0].Addr().As4()
 		budpIp[1] -= 128
 		budpIp[1] -= 128
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
 	} else {
 	} else {
-		budpIp := vpnIpNet.Addr().As16()
-		budpIp[13] -= 128
+		budpIp := vpnNetworks[0].Addr().As16()
+		// beef for funsies
+		budpIp[2] = 190
+		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
 	}
-	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{})
+	_, _, myPrivKey, myPEM := NewTestCert(v, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
 
 	caB, err := caCrt.MarshalPEM()
 	caB, err := caCrt.MarshalPEM()
 	if err != nil {
 	if err != nil {
@@ -88,11 +99,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
 	}
 	}
 
 
 	if overrides != nil {
 	if overrides != nil {
-		err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice)
+		final := m{}
+		err = mergo.Merge(&final, overrides, mergo.WithAppendSlice)
 		if err != nil {
 		if err != nil {
 			panic(err)
 			panic(err)
 		}
 		}
-		mc = overrides
+		err = mergo.Merge(&final, mc, mergo.WithAppendSlice)
+		if err != nil {
+			panic(err)
+		}
+		mc = final
 	}
 	}
 
 
 	cb, err := yaml.Marshal(mc)
 	cb, err := yaml.Marshal(mc)
@@ -109,7 +125,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe
 		panic(err)
 		panic(err)
 	}
 	}
 
 
-	return control, vpnIpNet, udpAddr, c
+	return control, vpnNetworks, udpAddr, c
 }
 }
 
 
 type doneCb func()
 type doneCb func()
@@ -132,27 +148,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 
 
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	// Send a packet from them to me
-	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
+	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 
 	// And once more from me to them
 	// And once more from me to them
-	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
+	controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 }
 
 
-func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) {
 	// Get both host infos
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
-	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
+	//TODO: we may want to loop over each vpnAddr and assert all the things
+	hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false)
+	assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA")
 
 
-	hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
-	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
+	hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false)
+	assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB")
 
 
 	// Check that both vpn and real addr are correct
 	// Check that both vpn and real addr are correct
-	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
-	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
+	assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A")
+	assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B")
 
 
 	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
 	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
 	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
 	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
@@ -179,6 +196,33 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp
 }
 }
 
 
 func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	if toIp.Is6() {
+		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
+	} else {
+		assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort)
+	}
+}
+
+func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
+	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+	assert.NotNil(t, v6, "No ipv6 data found")
+
+	assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect")
+
+	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	assert.NotNil(t, udp, "No udp data found")
+
+	assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect")
+	assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect")
+
+	data := packet.ApplicationLayer()
+	assert.NotNil(t, data)
+	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
+}
+
+func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
 	assert.NotNil(t, v4, "No ipv4 data found")
@@ -197,6 +241,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 }
 
 
+func getAddrs(ns []netip.Prefix) []netip.Addr {
+	var a []netip.Addr
+	for _, n := range ns {
+		a = append(a, n.Addr())
+	}
+	return a
+}
+
 func NewTestLogger() *logrus.Logger {
 func NewTestLogger() *logrus.Logger {
 	l := logrus.New()
 	l := logrus.New()
 
 

+ 4 - 3
e2e/router/hostmap.go

@@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	var lines []string
 	var lines []string
 	var globalLines []*edge
 	var globalLines []*edge
 
 
-	clusterName := strings.Trim(c.GetCert().Name(), " ")
-	clusterVpnIp := c.GetCert().Networks()[0].Addr()
+	crt := c.GetCertState().GetDefaultCertificate()
+	clusterName := strings.Trim(crt.Name(), " ")
+	clusterVpnIp := crt.Networks()[0].Addr()
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 	r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
 
 
 	hm := c.GetHostmap()
 	hm := c.GetHostmap()
@@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	for _, idx := range indexes {
 	for _, idx := range indexes {
 		hi, ok := hm.Indexes[idx]
 		hi, ok := hm.Indexes[idx]
 		if ok {
 		if ok {
-			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
+			r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs())
 			remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
 			remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 			_ = hi
 			_ = hi

+ 44 - 22
e2e/router/router.go

@@ -10,8 +10,8 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"reflect"
 	"reflect"
+	"regexp"
 	"sort"
 	"sort"
-	"strings"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 			panic("Duplicate listen address: " + addr.String())
 			panic("Duplicate listen address: " + addr.String())
 		}
 		}
 
 
-		r.vpnControls[c.GetVpnIp()] = c
+		for _, vpnAddr := range c.GetVpnAddrs() {
+			r.vpnControls[vpnAddr] = c
+		}
+
 		r.controls[addr] = c
 		r.controls[addr] = c
 	}
 	}
 
 
@@ -213,11 +216,11 @@ func (r *R) renderFlow() {
 			continue
 			continue
 		}
 		}
 		participants[addr] = struct{}{}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr.String(), ":", "-", 1)
+		sanAddr := normalizeName(addr.String())
 		participantsVals = append(participantsVals, sanAddr)
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
-			sanAddr, e.packet.from.GetVpnIp(), sanAddr,
+			sanAddr, e.packet.from.GetVpnAddrs(), sanAddr,
 		)
 		)
 	}
 	}
 
 
@@ -250,9 +253,9 @@ func (r *R) renderFlow() {
 
 
 			fmt.Fprintf(f,
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.from.GetUDPAddr().String()),
 				line,
 				line,
-				strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
+				normalizeName(p.to.GetUDPAddr().String()),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 			)
 		}
 		}
@@ -267,6 +270,11 @@ func (r *R) renderFlow() {
 	}
 	}
 }
 }
 
 
+func normalizeName(s string) string {
+	rx := regexp.MustCompile("[\\[\\]\\:]")
+	return rx.ReplaceAllLiteralString(s, "_")
+}
+
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria.
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
 // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered
@@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
+		return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0
 	})
 	})
 
 
 	s := renderHostmaps(c...)
 	s := renderHostmaps(c...)
@@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 		// Nope, lets push the sender along
 		// Nope, lets push the sender along
 		case p := <-udpTx:
 		case p := <-udpTx:
 			r.Lock()
 			r.Lock()
-			c := r.getControl(sender.GetUDPAddr(), p.To, p)
+			a := sender.GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 			if c == nil {
 				r.Unlock()
 				r.Unlock()
-				panic("No control for udp tx")
+				panic("No control for udp tx " + a.String())
 			}
 			}
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			fp := r.unlockedInjectFlow(sender, c, p, false)
 			c.InjectUDPPacket(p)
 			c.InjectUDPPacket(p)
@@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 		} else {
 			// we are a udp tx, route and continue
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
 			p := rx.Interface().(*udp.Packet)
-			c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
+			a := cm[x].GetUDPAddr()
+			c := r.getControl(a, p.To, p)
 			if c == nil {
 			if c == nil {
 				r.Unlock()
 				r.Unlock()
-				panic("No control for udp tx")
+				panic(fmt.Sprintf("No control for udp tx %s", p.To))
 			}
 			}
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			fp := r.unlockedInjectFlow(cm[x], c, p, false)
 			c.InjectUDPPacket(p)
 			c.InjectUDPPacket(p)
@@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C
 }
 }
 
 
 func (r *R) formatUdpPacket(p *packet) string {
 func (r *R) formatUdpPacket(p *packet) string {
-	packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
-	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
-	if v4 == nil {
-		panic("not an ipv4 packet")
+	var packet gopacket.Packet
+	var srcAddr netip.Addr
+
+	packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy)
+	if packet.ErrorLayer() == nil {
+		v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
+	} else {
+		packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy)
+		v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
+		if v6 == nil {
+			panic("not an ipv6 packet")
+		}
+		srcAddr, _ = netip.AddrFromSlice(v6.SrcIP)
 	}
 	}
 
 
 	from := "unknown"
 	from := "unknown"
-	srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
 	if c, ok := r.vpnControls[srcAddr]; ok {
 	if c, ok := r.vpnControls[srcAddr]; ok {
 		from = c.GetUDPAddr().String()
 		from = c.GetUDPAddr().String()
 	}
 	}
 
 
-	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
-	if udp == nil {
+	udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
+	if udpLayer == nil {
 		panic("not a udp packet")
 		panic("not a udp packet")
 	}
 	}
 
 
 	data := packet.ApplicationLayer()
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
-		udp.SrcPort,
-		udp.DstPort,
+		normalizeName(from),
+		normalizeName(p.to.GetUDPAddr().String()),
+		udpLayer.SrcPort,
+		udpLayer.DstPort,
 		string(data.Payload()),
 		string(data.Payload()),
 	)
 	)
 }
 }

+ 13 - 4
examples/config.yml

@@ -13,6 +13,12 @@ pki:
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
   #disconnect_invalid: true
   #disconnect_invalid: true
 
 
+  # default_version controls which certificate version is used in handshakes.
+  # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`.
+  # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`.
+  # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed.
+  # default_version: 1
+
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.
 # The syntax is:
 # The syntax is:
@@ -336,10 +342,13 @@ firewall:
   #   host: `any` or a literal hostname, ie `test-host`
   #   host: `any` or a literal hostname, ie `test-host`
   #   group: `any` or a literal group name, ie `default-group`
   #   group: `any` or a literal group name, ie `default-group`
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
-  #   cidr: a remote CIDR, `0.0.0.0/0` is any.
-  #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
-  #      Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate
-  #      if `default_local_cidr_any` is false, otherwise its `any`.
+  #   cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. //TODO: we have a problem, firewall needs to understand this and should probably allow `any` for both
+  #   local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes.
+  #   //TODO: probably should have an `any` that covers both ip versions
+  #     If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network.
+  #     Otherwise the default is any vpn network assigned to via the certificate.
+  #     `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release.
+  #     If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation.
   #   ca_name: An issuing CA name
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
   #   ca_sha: An issuing CA shasum
 
 

+ 67 - 55
firewall.go

@@ -8,6 +8,7 @@ import (
 	"hash/fnv"
 	"hash/fnv"
 	"net/netip"
 	"net/netip"
 	"reflect"
 	"reflect"
+	"slices"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -22,7 +23,8 @@ import (
 )
 )
 
 
 type FirewallInterface interface {
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
+	//TODO: name these better addr, localAddr. Are they vpnAddrs?
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
 }
 }
 
 
 type conn struct {
 type conn struct {
@@ -51,9 +53,12 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 	DefaultTimeout time.Duration //linux: 600s
 
 
-	// Used to ensure we don't emit local packets for ips we don't own
-	localIps          *bart.Table[struct{}]
-	assignedCIDR      netip.Prefix
+	// routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate.
+	// The vpn addresses are a full bit match while the unsafe networks only match the prefix
+	routableNetworks *bart.Table[struct{}]
+
+	// assignedNetworks is a list of vpn networks assigned to us in the certificate.
+	assignedNetworks  []netip.Prefix
 	hasUnsafeNetworks bool
 	hasUnsafeNetworks bool
 
 
 	rules        string
 	rules        string
@@ -67,9 +72,9 @@ type Firewall struct {
 }
 }
 
 
 type firewallMetrics struct {
 type firewallMetrics struct {
-	droppedLocalIP  metrics.Counter
-	droppedRemoteIP metrics.Counter
-	droppedNoRule   metrics.Counter
+	droppedLocalAddr  metrics.Counter
+	droppedRemoteAddr metrics.Counter
+	droppedNoRule     metrics.Counter
 }
 }
 
 
 type FirewallConntrack struct {
 type FirewallConntrack struct {
@@ -126,84 +131,87 @@ type firewallLocalCIDR struct {
 }
 }
 
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
+// The certificate provided should be the highest version loaded in memory.
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
 	//TODO: error on 0 duration
 	//TODO: error on 0 duration
-	var min, max time.Duration
+	var tmin, tmax time.Duration
 
 
 	if tcpTimeout < UDPTimeout {
 	if tcpTimeout < UDPTimeout {
-		min = tcpTimeout
-		max = UDPTimeout
+		tmin = tcpTimeout
+		tmax = UDPTimeout
 	} else {
 	} else {
-		min = UDPTimeout
-		max = tcpTimeout
+		tmin = UDPTimeout
+		tmax = tcpTimeout
 	}
 	}
 
 
-	if defaultTimeout < min {
-		min = defaultTimeout
-	} else if defaultTimeout > max {
-		max = defaultTimeout
+	if defaultTimeout < tmin {
+		tmin = defaultTimeout
+	} else if defaultTimeout > tmax {
+		tmax = defaultTimeout
 	}
 	}
 
 
-	localIps := new(bart.Table[struct{}])
-	var assignedCIDR netip.Prefix
-	var assignedSet bool
+	routableNetworks := new(bart.Table[struct{}])
+	var assignedNetworks []netip.Prefix
 	for _, network := range c.Networks() {
 	for _, network := range c.Networks() {
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
 		nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
-		localIps.Insert(nprefix, struct{}{})
-
-		if !assignedSet {
-			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = nprefix
-			assignedSet = true
-		}
+		routableNetworks.Insert(nprefix, struct{}{})
+		assignedNetworks = append(assignedNetworks, network)
 	}
 	}
 
 
 	hasUnsafeNetworks := false
 	hasUnsafeNetworks := false
 	for _, n := range c.UnsafeNetworks() {
 	for _, n := range c.UnsafeNetworks() {
-		localIps.Insert(n, struct{}{})
+		routableNetworks.Insert(n, struct{}{})
 		hasUnsafeNetworks = true
 		hasUnsafeNetworks = true
 	}
 	}
 
 
 	return &Firewall{
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 		Conntrack: &FirewallConntrack{
 			Conns:      make(map[firewall.Packet]*conn),
 			Conns:      make(map[firewall.Packet]*conn),
-			TimerWheel: NewTimerWheel[firewall.Packet](min, max),
+			TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
 		},
 		},
 		InRules:           newFirewallTable(),
 		InRules:           newFirewallTable(),
 		OutRules:          newFirewallTable(),
 		OutRules:          newFirewallTable(),
 		TCPTimeout:        tcpTimeout,
 		TCPTimeout:        tcpTimeout,
 		UDPTimeout:        UDPTimeout,
 		UDPTimeout:        UDPTimeout,
 		DefaultTimeout:    defaultTimeout,
 		DefaultTimeout:    defaultTimeout,
-		localIps:          localIps,
-		assignedCIDR:      assignedCIDR,
+		routableNetworks:  routableNetworks,
+		assignedNetworks:  assignedNetworks,
 		hasUnsafeNetworks: hasUnsafeNetworks,
 		hasUnsafeNetworks: hasUnsafeNetworks,
 		l:                 l,
 		l:                 l,
 
 
 		incomingMetrics: firewallMetrics{
 		incomingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
 		},
 		},
 		outgoingMetrics: firewallMetrics{
 		outgoingMetrics: firewallMetrics{
-			droppedLocalIP:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
-			droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
-			droppedNoRule:   metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
+			droppedLocalAddr:  metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil),
+			droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil),
+			droppedNoRule:     metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
 		},
 		},
 	}
 	}
 }
 }
 
 
-func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) {
+func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
+	certificate := cs.getCertificate(cert.Version2)
+	if certificate == nil {
+		certificate = cs.getCertificate(cert.Version1)
+	}
+
+	if certificate == nil {
+		panic("No certificate available to reconfigure the firewall")
+	}
+
 	fw := NewFirewall(
 	fw := NewFirewall(
 		l,
 		l,
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
 		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
-		nc,
+		certificate,
 		//TODO: max_connections
 		//TODO: max_connections
 	)
 	)
 
 
-	//TODO: Flip to false after v1.9 release
-	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true)
+	fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
 
 
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	inboundAction := c.GetString("firewall.inbound_action", "drop")
 	switch inboundAction {
 	switch inboundAction {
@@ -283,7 +291,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		fp = ft.TCP
 		fp = ft.TCP
 	case firewall.ProtoUDP:
 	case firewall.ProtoUDP:
 		fp = ft.UDP
 		fp = ft.UDP
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		fp = ft.ICMP
 		fp = ft.ICMP
 	case firewall.ProtoAny:
 	case firewall.ProtoAny:
 		fp = ft.AnyProto
 		fp = ft.AnyProto
@@ -424,26 +432,25 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 	}
 
 
 	// Make sure remote address matches nebula certificate
 	// Make sure remote address matches nebula certificate
-	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-		_, ok := remoteCidr.Lookup(fp.RemoteIP)
+	if h.networks != nil {
+		_, ok := h.networks.Lookup(fp.RemoteAddr)
 		if !ok {
 		if !ok {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 			return ErrInvalidRemoteIP
 		}
 		}
 	} else {
 	} else {
 		// Simple case: Certificate has one IP and no subnets
 		// Simple case: Certificate has one IP and no subnets
-		if fp.RemoteIP != h.vpnIp {
-			f.metrics(incoming).droppedRemoteIP.Inc(1)
+		//TODO: we can make this more performant
+		if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
+			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 			return ErrInvalidRemoteIP
 		}
 		}
 	}
 	}
 
 
 	// Make sure we are supposed to be handling this local ip address
 	// Make sure we are supposed to be handling this local ip address
-	//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
-	_, ok := f.localIps.Lookup(fp.LocalIP)
+	_, ok := f.routableNetworks.Lookup(fp.LocalAddr)
 	if !ok {
 	if !ok {
-		f.metrics(incoming).droppedLocalIP.Inc(1)
+		f.metrics(incoming).droppedLocalAddr.Inc(1)
 		return ErrInvalidLocalIP
 		return ErrInvalidLocalIP
 	}
 	}
 
 
@@ -629,7 +636,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
 		if ft.UDP.match(p, incoming, c, caPool) {
 		if ft.UDP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
-	case firewall.ProtoICMP:
+	case firewall.ProtoICMP, firewall.ProtoICMPv6:
 		if ft.ICMP.match(p, incoming, c, caPool) {
 		if ft.ICMP.match(p, incoming, c, caPool) {
 			return true
 			return true
 		}
 		}
@@ -859,9 +866,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
 	}
 	}
 
 
 	matched := false
 	matched := false
-	prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
+	prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen())
 	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
 	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
-		if prefix.Contains(p.RemoteIP) && val.match(p, c) {
+		if prefix.Contains(p.RemoteAddr) && val.match(p, c) {
 			matched = true
 			matched = true
 			return false
 			return false
 		}
 		}
@@ -877,9 +884,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
 			return nil
 			return nil
 		}
 		}
 
 
-		localIp = f.assignedCIDR
+		for _, network := range f.assignedNetworks {
+			flc.LocalCIDR.Insert(network, struct{}{})
+		}
+		return nil
+
 	} else if localIp.Bits() == 0 {
 	} else if localIp.Bits() == 0 {
 		flc.Any = true
 		flc.Any = true
+		return nil
 	}
 	}
 
 
 	flc.LocalCIDR.Insert(localIp, struct{}{})
 	flc.LocalCIDR.Insert(localIp, struct{}{})
@@ -895,7 +907,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate
 		return true
 		return true
 	}
 	}
 
 
-	_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalAddr)
 	return ok
 	return ok
 }
 }
 
 

+ 11 - 10
firewall/packet.go

@@ -9,18 +9,19 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 const (
 const (
-	ProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
-	ProtoTCP  = 6
-	ProtoUDP  = 17
-	ProtoICMP = 1
+	ProtoAny    = 0 // When we want to handle HOPOPT (0) we can change this, if ever
+	ProtoTCP    = 6
+	ProtoUDP    = 17
+	ProtoICMP   = 1
+	ProtoICMPv6 = 58
 
 
 	PortAny      = 0  // Special value for matching `port: any`
 	PortAny      = 0  // Special value for matching `port: any`
 	PortFragment = -1 // Special value for matching `port: fragment`
 	PortFragment = -1 // Special value for matching `port: fragment`
 )
 )
 
 
 type Packet struct {
 type Packet struct {
-	LocalIP    netip.Addr
-	RemoteIP   netip.Addr
+	LocalAddr  netip.Addr
+	RemoteAddr netip.Addr
 	LocalPort  uint16
 	LocalPort  uint16
 	RemotePort uint16
 	RemotePort uint16
 	Protocol   uint8
 	Protocol   uint8
@@ -29,8 +30,8 @@ type Packet struct {
 
 
 func (fp *Packet) Copy() *Packet {
 func (fp *Packet) Copy() *Packet {
 	return &Packet{
 	return &Packet{
-		LocalIP:    fp.LocalIP,
-		RemoteIP:   fp.RemoteIP,
+		LocalAddr:  fp.LocalAddr,
+		RemoteAddr: fp.RemoteAddr,
 		LocalPort:  fp.LocalPort,
 		LocalPort:  fp.LocalPort,
 		RemotePort: fp.RemotePort,
 		RemotePort: fp.RemotePort,
 		Protocol:   fp.Protocol,
 		Protocol:   fp.Protocol,
@@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 		proto = fmt.Sprintf("unknown %v", fp.Protocol)
 	}
 	}
 	return json.Marshal(m{
 	return json.Marshal(m{
-		"LocalIP":    fp.LocalIP.String(),
-		"RemoteIP":   fp.RemoteIP.String(),
+		"LocalAddr":  fp.LocalAddr.String(),
+		"RemoteAddr": fp.RemoteAddr.String(),
 		"LocalPort":  fp.LocalPort,
 		"LocalPort":  fp.LocalPort,
 		"RemotePort": fp.RemotePort,
 		"RemotePort": fp.RemotePort,
 		"Protocol":   proto,
 		"Protocol":   proto,

+ 40 - 36
firewall_test.go

@@ -13,6 +13,7 @@ import (
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 )
 
 
 func TestNewFirewall(t *testing.T) {
 func TestNewFirewall(t *testing.T) {
@@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) {
 				InvertedGroups: map[string]struct{}{"default-group": {}},
 				InvertedGroups: map[string]struct{}{"default-group": {}},
 			},
 			},
 		},
 		},
-		vpnIp: netip.MustParseAddr("1.2.3.4"),
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
 	}
-	h.CreateRemoteCIDR(&c)
+	h.buildNetworks(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
 
 
 	// test remote mismatch
 	// test remote mismatch
-	oldRemote := p.RemoteIP
-	p.RemoteIP = netip.MustParseAddr("1.2.3.10")
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
-	p.RemoteIP = oldRemote
+	p.RemoteAddr = oldRemote
 
 
 	// ensure signer doesn't get in the way of group checks
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
@@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		}
 		}
 		ip := netip.MustParsePrefix("9.254.254.254/32")
 		ip := netip.MustParsePrefix("9.254.254.254/32")
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
@@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			InvertedGroups: map[string]struct{}{"nope": {}},
 			InvertedGroups: map[string]struct{}{"nope": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
@@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			InvertedGroups: map[string]struct{}{"good-group": {}},
 			InvertedGroups: map[string]struct{}{"good-group": {}},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
@@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h.CreateRemoteCIDR(c.Certificate)
+	h.buildNetworks(c.Certificate)
 
 
 	c1 := cert.CachedCertificate{
 	c1 := cert.CachedCertificate{
 		Certificate: &dummyCert{
 		Certificate: &dummyCert{
@@ -345,7 +346,7 @@ func TestFirewall_Drop2(t *testing.T) {
 			peerCert: &c1,
 			peerCert: &c1,
 		},
 		},
 	}
 	}
-	h1.CreateRemoteCIDR(c1.Certificate)
+	h1.buildNetworks(c1.Certificate)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -364,8 +365,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  1,
 		LocalPort:  1,
 		RemotePort: 1,
 		RemotePort: 1,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -391,9 +392,9 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 			peerCert: &c1,
 		},
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h1.CreateRemoteCIDR(c1.Certificate)
+	h1.buildNetworks(c1.Certificate)
 
 
 	c2 := cert.CachedCertificate{
 	c2 := cert.CachedCertificate{
 		Certificate: &dummyCert{
 		Certificate: &dummyCert{
@@ -406,9 +407,9 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 			peerCert: &c2,
 		},
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h2.CreateRemoteCIDR(c2.Certificate)
+	h2.buildNetworks(c2.Certificate)
 
 
 	c3 := cert.CachedCertificate{
 	c3 := cert.CachedCertificate{
 		Certificate: &dummyCert{
 		Certificate: &dummyCert{
@@ -421,9 +422,9 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 			peerCert: &c3,
 		},
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h3.CreateRemoteCIDR(c3.Certificate)
+	h3.buildNetworks(c3.Certificate)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -446,8 +447,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    netip.MustParseAddr("1.2.3.4"),
-		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
+		LocalAddr:  netip.MustParseAddr("1.2.3.4"),
+		RemoteAddr: netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -468,9 +469,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: network.Addr(),
+		vpnAddrs: []netip.Addr{network.Addr()},
 	}
 	}
-	h.CreateRemoteCIDR(c.Certificate)
+	h.buildNetworks(c.Certificate)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -622,55 +623,58 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	// Test a bad rule definition
 	// Test a bad rule definition
 	c := &dummyCert{}
 	c := &dummyCert{}
+	cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
+	require.NoError(t, err)
+
 	conf := config.NewC(l)
 	conf := config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
-	_, err := NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
 
 
 	// Test both port and code
 	// Test both port and code
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
 
 
 	// Test missing host, group, cidr, ca_name and ca_sha
 	// Test missing host, group, cidr, ca_name and ca_sha
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
 
 
 	// Test code/port error
 	// Test code/port error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
 
 
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
 
 
 	// Test proto error
 	// Test proto error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
 
 
 	// Test cidr parse error
 	// Test cidr parse error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 
 	// Test local_cidr parse error
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 
 	// Test both group and groups
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
-	_, err = NewFirewallFromConfig(l, c, conf)
+	_, err = NewFirewallFromConfig(l, cs, conf)
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
 }
 }
 
 

+ 0 - 1
go.mod

@@ -21,7 +21,6 @@ require (
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/sirupsen/logrus v1.9.3
 	github.com/sirupsen/logrus v1.9.3
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
-	github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
 	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6
 	github.com/stretchr/testify v1.9.0
 	github.com/stretchr/testify v1.9.0
 	github.com/vishvananda/netlink v1.3.0
 	github.com/vishvananda/netlink v1.3.0

+ 0 - 2
go.sum

@@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
-github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw=
 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

+ 162 - 80
handshake_ix.go

@@ -2,10 +2,12 @@ package nebula
 
 
 import (
 import (
 	"net/netip"
 	"net/netip"
+	"slices"
 	"time"
 	"time"
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 )
 )
 
 
@@ -16,30 +18,60 @@ import (
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	err := f.handshakeManager.allocateIndex(hh)
 	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 		return false
 	}
 	}
 
 
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
-	hh.hostinfo.ConnectionState = ci
+	// If we're connecting to a v6 address we must use a v2 cert
+	cs := f.pki.getCertState()
+	v := cs.defaultVersion
+	for _, a := range hh.hostinfo.vpnAddrs {
+		if a.Is6() {
+			v = cert.Version2
+			break
+		}
+	}
 
 
-	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hh.hostinfo.localIndexId,
-		Time:           uint64(time.Now().UnixNano()),
-		Cert:           certState.RawCertificateNoKey,
+	crt := cs.getCertificate(v)
+	if crt == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate is available")
+		return false
 	}
 	}
 
 
-	hsBytes := []byte{}
+	crtHs := cs.getHandshakeBytes(v)
+	if crtHs == nil {
+		f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", v).
+			Error("Failed to create connection state")
+		return false
+	}
+	hh.hostinfo.ConnectionState = ci
 
 
 	hs := &NebulaHandshake{
 	hs := &NebulaHandshake{
-		Details: hsProto,
+		Details: &NebulaHandshakeDetails{
+			InitiatorIndex: hh.hostinfo.localIndexId,
+			Time:           uint64(time.Now().UnixNano()),
+			Cert:           crtHs,
+			CertVersion:    uint32(v),
+		},
 	}
 	}
-	hsBytes, err = hs.Marshal()
 
 
+	hsBytes, err := hs.Marshal()
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 		return false
 	}
 	}
@@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 		return false
 	}
 	}
@@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 }
 }
 
 
 func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
-	certState := f.pki.GetCertState()
-	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
+	cs := f.pki.getCertState()
+	crt := cs.GetDefaultCertificate()
+	if crt == nil {
+		f.l.WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
+			WithField("certVersion", cs.defaultVersion).
+			Error("Unable to handshake with host because no certificate is available")
+	}
+
+	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
+	if err != nil {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to create connection state")
+		return
+	}
+
 	// Mark packet 1 as seen so it doesn't show up as missed
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 	ci.window.Update(f.l, 1)
 
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed to call noise.ReadMessage")
 		return
 		return
 	}
 	}
 
 
 	hs := &NebulaHandshake{}
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	err = hs.Unmarshal(msg)
-	/*
-		l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
-	*/
 	if err != nil || hs.Details == nil {
 	if err != nil || hs.Details == nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			Error("Failed unmarshal handshake message")
 		return
 		return
 	}
 	}
 
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
 	if err != nil {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 		return
 	}
 	}
 
 
+	if remoteCert.Certificate.Version() != ci.myCert.Version() {
+		// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
+		rc := cs.getCertificate(remoteCert.Certificate.Version())
+		if rc == nil {
+			f.l.WithError(err).WithField("udpAddr", addr).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
+				Info("Unable to handshake with host due to missing certificate version")
+			return
+		}
+
+		// Record the certificate we are actually using
+		ci.myCert = rc
+	}
+
 	if len(remoteCert.Certificate.Networks()) == 0 {
 	if len(remoteCert.Certificate.Networks()) == 0 {
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 		e := f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
@@ -111,30 +171,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 		return
 	}
 	}
 
 
-	vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
+	var vpnAddrs []netip.Addr
 	certName := remoteCert.Certificate.Name()
 	certName := remoteCert.Certificate.Name()
 	fingerprint := remoteCert.Fingerprint
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
 	issuer := remoteCert.Certificate.Issuer()
 
 
-	if vpnIp == f.myVpnNet.Addr() {
-		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
-			WithField("certName", certName).
-			WithField("fingerprint", fingerprint).
-			WithField("issuer", issuer).
-			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
-		return
-	}
-
-	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	for _, network := range remoteCert.Certificate.Networks() {
+		vpnAddr := network.Addr()
+		_, found := f.myVpnAddrsTable.Lookup(vpnAddr)
+		if found {
+			f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("fingerprint", fingerprint).
+				WithField("issuer", issuer).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself")
 			return
 			return
 		}
 		}
+
+		if addr.IsValid() {
+			if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
+				f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+				return
+			}
+		}
+
+		vpnAddrs = append(vpnAddrs, vpnAddr)
 	}
 	}
 
 
 	myIndex, err := generateIndex(f.l)
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -146,17 +212,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		ConnectionState:   ci,
 		ConnectionState:   ci,
 		localIndexId:      myIndex,
 		localIndexId:      myIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
 		remoteIndexId:     hs.Details.InitiatorIndex,
-		vpnIp:             vpnIp,
+		vpnAddrs:          vpnAddrs,
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}
 	}
 
 
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("issuer", issuer).
@@ -165,13 +231,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		Info("Handshake message received")
 		Info("Handshake message received")
 
 
 	hs.Details.ResponderIndex = myIndex
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = certState.RawCertificateNoKey
+	hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
+	if hs.Details.Cert == nil {
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("certVersion", ci.myCert.Version()).
+			Error("Unable to handshake with host because no certificate handshake bytes is available")
+		return
+	}
+
+	hs.Details.CertVersion = uint32(ci.myCert.Version())
 	// Update the time in case their clock is way off from ours
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
 
 	hsBytes, err := hs.Marshal()
 	hsBytes, err := hs.Marshal()
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -182,14 +261,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 		return
 	} else if dKey == nil || eKey == nil {
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -213,9 +292,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
 
-	hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
 	hostinfo.SetRemote(addr)
-	hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
+	hostinfo.buildNetworks(remoteCert.Certificate)
 
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
 	if err != nil {
@@ -225,7 +304,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				// the preferred remote
-				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
 			}
 
 
 			msg = existing.HandshakePacket[2]
 			msg = existing.HandshakePacket[2]
@@ -233,11 +312,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			if addr.IsValid() {
 			if addr.IsValid() {
 				err := f.outside.WriteTo(msg, addr)
 				err := f.outside.WriteTo(msg, addr)
 				if err != nil {
 				if err != nil {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 						WithError(err).Error("Failed to send handshake message")
 				} else {
 				} else {
-					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 						Info("Handshake message sent")
 				}
 				}
@@ -247,16 +326,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					f.l.Error("Handshake send failed: both addr and via are nil.")
 					return
 					return
 				}
 				}
-				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
 				f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-				f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp).
+				f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 					Info("Handshake message sent")
 					Info("Handshake message sent")
 				return
 				return
 			}
 			}
 		case ErrExistingHostInfo:
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
 				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
@@ -267,23 +346,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				Info("Handshake too old")
 				Info("Handshake too old")
 
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+			f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			return
 			return
 		case ErrLocalIndexCollision:
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
+				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs).
 				Error("Failed to add HostInfo due to localIndex collision")
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 			return
 		default:
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -299,7 +378,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	if addr.IsValid() {
 	if addr.IsValid() {
 		err = f.outside.WriteTo(msg, addr)
 		err = f.outside.WriteTo(msg, addr)
 		if err != nil {
 		if err != nil {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -307,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake")
 				WithError(err).Error("Failed to send handshake")
 		} else {
 		} else {
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
 				WithField("fingerprint", fingerprint).
 				WithField("issuer", issuer).
 				WithField("issuer", issuer).
@@ -320,9 +399,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			f.l.Error("Handshake send failed: both addr and via are nil.")
 			return
 			return
 		}
 		}
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
 		f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
-		f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp).
+		f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
 			WithField("issuer", issuer).
 			WithField("issuer", issuer).
@@ -349,8 +428,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 
 	hostinfo := hh.hostinfo
 	hostinfo := hh.hostinfo
 	if addr.IsValid() {
 	if addr.IsValid() {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		//TODO: this is kind of nonsense now
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
+			f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 			return false
 		}
 		}
 	}
 	}
@@ -358,7 +438,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci := hostinfo.ConnectionState
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 			Error("Failed to call noise.ReadMessage")
 
 
@@ -367,7 +447,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		// near future
 		// near future
 		return false
 		return false
 	} else if dKey == nil || eKey == nil {
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 			Error("Noise did not arrive at a key")
 
 
@@ -379,16 +459,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hs := &NebulaHandshake{}
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
 		return true
 		return true
 	}
 	}
 
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
+	remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
 	if err != nil {
 	if err != nil {
-		e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 
 
 		if f.l.Level > logrus.DebugLevel {
 		if f.l.Level > logrus.DebugLevel {
@@ -413,7 +493,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		return true
 		return true
 	}
 	}
 
 
-	vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap()
+	vpnNetworks := remoteCert.Certificate.Networks()
 	certName := remoteCert.Certificate.Name()
 	certName := remoteCert.Certificate.Name()
 	fingerprint := remoteCert.Fingerprint
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
 	issuer := remoteCert.Certificate.Issuer()
@@ -430,12 +510,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	if addr.IsValid() {
 	if addr.IsValid() {
 		hostinfo.SetRemote(addr)
 		hostinfo.SetRemote(addr)
 	} else {
 	} else {
-		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
+		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
+	}
+
+	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
+	for i, n := range vpnNetworks {
+		vpnAddrs[i] = n.Addr()
 	}
 	}
 
 
 	// Ensure the right host responded
 	// Ensure the right host responded
-	if vpnIp != hostinfo.vpnIp {
-		f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp).
+	if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) {
+		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 			Info("Incorrect host responded to handshake")
@@ -444,14 +529,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 
 		// Create a new hostinfo/handshake for the intended vpn ip
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
+		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			//TODO: this doesnt know if its being added or is being used for caching a packet
 			//TODO: this doesnt know if its being added or is being used for caching a packet
 			// Block the current used address
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes = hostinfo.remotes
 			newHH.hostinfo.remotes.BlockRemote(addr)
 			newHH.hostinfo.remotes.BlockRemote(addr)
 
 
 			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
 			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
-				WithField("vpnIp", newHH.hostinfo.vpnIp).
+				WithField("vpnNetworks", vpnNetworks).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
 				Info("Blocked addresses for handshakes")
 				Info("Blocked addresses for handshakes")
 
 
@@ -459,11 +544,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			newHH.packetStore = hh.packetStore
 			newHH.packetStore = hh.packetStore
 			hh.packetStore = []*cachedPacket{}
 			hh.packetStore = []*cachedPacket{}
 
 
-			// Get the correct remote list for the host we did handshake with
-			hostinfo.SetRemote(addr)
-			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
-			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
-			hostinfo.vpnIp = vpnIp
+			// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
+			hostinfo.vpnAddrs = vpnAddrs
 			f.sendCloseTunnel(hostinfo)
 			f.sendCloseTunnel(hostinfo)
 		})
 		})
 
 
@@ -474,7 +556,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 	ci.window.Update(f.l, 2)
 
 
 	duration := time.Since(hh.startTime).Nanoseconds()
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
+	f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
 		WithField("issuer", issuer).
@@ -485,7 +567,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		Info("Handshake message received")
 		Info("Handshake message received")
 
 
 	// Build up the radix for the firewall if we have subnets in the cert
 	// Build up the radix for the firewall if we have subnets in the cert
-	hostinfo.CreateRemoteCIDR(remoteCert.Certificate)
+	hostinfo.buildNetworks(remoteCert.Certificate)
 
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	f.handshakeManager.Complete(hostinfo, f)
 	f.handshakeManager.Complete(hostinfo, f)

+ 135 - 87
handshake_manager.go

@@ -13,6 +13,7 @@ import (
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
@@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) Run(ctx context.Context) {
-	clockSource := time.NewTicker(c.config.tryInterval)
+func (hm *HandshakeManager) Run(ctx context.Context) {
+	clockSource := time.NewTicker(hm.config.tryInterval)
 	defer clockSource.Stop()
 	defer clockSource.Stop()
 
 
 	for {
 	for {
 		select {
 		select {
 		case <-ctx.Done():
 		case <-ctx.Done():
 			return
 			return
-		case vpnIP := <-c.trigger:
-			c.handleOutbound(vpnIP, true)
+		case vpnIP := <-hm.trigger:
+			hm.handleOutbound(vpnIP, true)
 		case now := <-clockSource.C:
 		case now := <-clockSource.C:
-			c.NextOutboundHandshakeTimerTick(now)
+			hm.NextOutboundHandshakeTimerTick(now)
 		}
 		}
 	}
 	}
 }
 }
@@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
-	c.OutboundHandshakeTimer.Advance(now)
+func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
+	hm.OutboundHandshakeTimer.Advance(now)
 	for {
 	for {
-		vpnIp, has := c.OutboundHandshakeTimer.Purge()
+		vpnIp, has := hm.OutboundHandshakeTimer.Purge()
 		if !has {
 		if !has {
 			break
 			break
 		}
 		}
-		c.handleOutbound(vpnIp, false)
+		hm.handleOutbound(vpnIp, false)
 	}
 	}
 }
 }
 
 
@@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp})
 	}
 	}
 
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
@@ -267,11 +268,18 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 		for _, relay := range hostinfo.remotes.relays {
-			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
+			// Don't relay to myself
+			if relay == vpnIp {
 				continue
 				continue
 			}
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+
+			// Don't relay through the host I'm trying to connect to
+			_, found := hm.f.myVpnAddrsTable.Lookup(relay)
+			if found {
+				continue
+			}
+
+			relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay)
 			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hm.f.Handshake(relay)
 				hm.f.Handshake(relay)
@@ -286,17 +294,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 				case Requested:
 				case Requested:
 					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 
 
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
-					// Re-send the CreateRelay request, in case the previous one was lost.
 					m := NebulaControl{
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					}
+
+					switch relayHostInfo.GetCert().Certificate.Version() {
+					case cert.Version1:
+						if !hm.f.myVpnAddrs[0].Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+							continue
+						}
+
+						if !vpnIp.Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+							continue
+						}
+
+						b := hm.f.myVpnAddrs[0].As4()
+						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+						b = vpnIp.As4()
+						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					case cert.Version2:
+						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					default:
+						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						continue
+					}
+
 					msg, err := m.Marshal()
 					msg, err := m.Marshal()
 					if err != nil {
 					if err != nil {
 						hostinfo.logger(hm.l).
 						hostinfo.logger(hm.l).
@@ -306,7 +332,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						// This must send over the hostinfo, not over hm.Hosts[ip]
 						// This must send over the hostinfo, not over hm.Hosts[ip]
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
+							"relayFrom":           hm.f.myVpnAddrs[0],
 							"relayTo":             vpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
 							"relay":               relay}).
 							"relay":               relay}).
@@ -316,7 +342,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					hostinfo.logger(hm.l).
 					hostinfo.logger(hm.l).
 						WithField("vpnIp", vpnIp).
 						WithField("vpnIp", vpnIp).
 						WithField("state", existingRelay.State).
 						WithField("state", existingRelay.State).
-						WithField("relay", relayHostInfo.vpnIp).
+						WithField("relay", relayHostInfo.vpnAddrs[0]).
 						Errorf("Relay unexpected state")
 						Errorf("Relay unexpected state")
 				}
 				}
 			} else {
 			} else {
@@ -327,16 +353,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 					}
 
 
-					//TODO: IPV6-WORK
-					myVpnIpB := hm.f.myVpnNet.Addr().As4()
-					theirVpnIpB := vpnIp.As4()
-
 					m := NebulaControl{
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
-						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					}
+
+					switch relayHostInfo.GetCert().Certificate.Version() {
+					case cert.Version1:
+						if !hm.f.myVpnAddrs[0].Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+							continue
+						}
+
+						if !vpnIp.Is4() {
+							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
+							continue
+						}
+
+						b := hm.f.myVpnAddrs[0].As4()
+						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+						b = vpnIp.As4()
+						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					case cert.Version2:
+						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					default:
+						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						continue
+					}
+
 					msg, err := m.Marshal()
 					msg, err := m.Marshal()
 					if err != nil {
 					if err != nil {
 						hostinfo.logger(hm.l).
 						hostinfo.logger(hm.l).
@@ -345,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					} else {
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnNet.Addr(),
+							"relayFrom":           hm.f.myVpnAddrs[0],
 							"relayTo":             vpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
 							"relay":               relay}).
@@ -381,10 +426,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 }
 }
 
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 	hm.Lock()
 
 
-	if hh, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnAddr]; ok {
 		// We are already trying to handshake with this vpn ip
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
 		if cacheCb != nil {
 			cacheCb(hh)
 			cacheCb(hh)
@@ -394,12 +439,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 	}
 	}
 
 
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
-		vpnIp:           vpnIp,
+		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[netip.Addr]struct{}{},
-			relayForByIp:  map[netip.Addr]*Relay{},
-			relayForByIdx: map[uint32]*Relay{},
+			relays:         map[netip.Addr]struct{}{},
+			relayForByAddr: map[netip.Addr]*Relay{},
+			relayForByIdx:  map[uint32]*Relay{},
 		},
 		},
 	}
 	}
 
 
@@ -407,9 +452,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 		hostinfo:  hostinfo,
 		hostinfo:  hostinfo,
 		startTime: time.Now(),
 		startTime: time.Now(),
 	}
 	}
-	hm.vpnIps[vpnIp] = hh
+	hm.vpnIps[vpnAddr] = hh
 	hm.metricInitiated.Inc(1)
 	hm.metricInitiated.Inc(1)
-	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
+	hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval)
 
 
 	if cacheCb != nil {
 	if cacheCb != nil {
 		cacheCb(hh)
 		cacheCb(hh)
@@ -417,21 +462,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands
 
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// We can trigger the handshake right now
 	// We can trigger the handshake right now
-	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp]
+	_, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr]
 	if !doTrigger {
 	if !doTrigger {
 		// Add any calculated remotes, and trigger early handshake if one found
 		// Add any calculated remotes, and trigger early handshake if one found
-		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp)
+		doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr)
 	}
 	}
 
 
 	if doTrigger {
 	if doTrigger {
 		select {
 		select {
-		case hm.trigger <- vpnIp:
+		case hm.trigger <- vpnAddr:
 		default:
 		default:
 		}
 		}
 	}
 	}
 
 
 	hm.Unlock()
 	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp)
+	hm.lightHouse.QueryServer(vpnAddr)
 	return hostinfo
 	return hostinfo
 }
 }
 
 
@@ -452,14 +497,14 @@ var (
 //
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.mainHostMap.Lock()
-	defer c.mainHostMap.Unlock()
-	c.Lock()
-	defer c.Unlock()
+func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
+	hm.mainHostMap.Lock()
+	defer hm.mainHostMap.Unlock()
+	hm.Lock()
+	defer hm.Unlock()
 
 
 	// Check if we already have a tunnel with this vpn ip
 	// Check if we already have a tunnel with this vpn ip
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]]
 	if found && existingHostInfo != nil {
 	if found && existingHostInfo != nil {
 		testHostInfo := existingHostInfo
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
 		for testHostInfo != nil {
@@ -476,31 +521,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			return existingHostInfo, ErrExistingHostInfo
 			return existingHostInfo, ErrExistingHostInfo
 		}
 		}
 
 
-		existingHostInfo.logger(c.l).Info("Taking new handshake")
+		existingHostInfo.logger(hm.l).Info("Taking new handshake")
 	}
 	}
 
 
-	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId]
 	if found {
 	if found {
 		// We have a collision, but for a different hostinfo
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 		return existingIndex, ErrLocalIndexCollision
 	}
 	}
 
 
-	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	existingPendingIndex, found := hm.indexes[hostinfo.localIndexId]
 	if found && existingPendingIndex.hostinfo != hostinfo {
 	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
 		// We have a collision, but for a different hostinfo
 		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
 	}
 	}
 
 
-	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
-	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
+	existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
+	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		// the remote ID. Just log about the situation as a note.
-		hostinfo.logger(c.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+		hostinfo.logger(hm.l).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
-	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
+	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 	return existingHostInfo, nil
 }
 }
 
 
@@ -518,7 +563,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
 		// the remote ID. Just log about the situation as a note.
 		// the remote ID. Just log about the situation as a note.
 		hostinfo.logger(hm.l).
 		hostinfo.logger(hm.l).
-			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
+			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs).
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
@@ -555,31 +600,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 	return errors.New("failed to generate unique localIndexId")
 }
 }
 
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	c.Lock()
-	defer c.Unlock()
-	c.unlockedDeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
 }
 }
 
 
-func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	delete(c.vpnIps, hostinfo.vpnIp)
-	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
+func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	for _, addr := range hostinfo.vpnAddrs {
+		delete(hm.vpnIps, addr)
 	}
 	}
 
 
-	delete(c.indexes, hostinfo.localIndexId)
-	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HandshakeHostInfo{}
+	if len(hm.vpnIps) == 0 {
+		hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 	}
 
 
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+	delete(hm.indexes, hostinfo.localIndexId)
+	if len(hm.indexes) == 0 {
+		hm.indexes = map[uint32]*HandshakeHostInfo{}
+	}
+
+	if hm.l.Level >= logrus.DebugLevel {
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps),
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Pending hostmap hostInfo deleted")
 			Debug("Pending hostmap hostInfo deleted")
 	}
 	}
 }
 }
 
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
+func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 	if hh != nil {
 		return hh.hostinfo
 		return hh.hostinfo
@@ -608,37 +656,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 	return hm.indexes[index]
 }
 }
 
 
-func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
-	return c.mainHostMap.GetPreferredRanges()
+func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix {
+	return hm.mainHostMap.GetPreferredRanges()
 }
 }
 
 
-func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
 
-	for _, v := range c.vpnIps {
+	for _, v := range hm.vpnIps {
 		f(v.hostinfo)
 		f(v.hostinfo)
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) ForEachIndex(f controlEach) {
-	c.RLock()
-	defer c.RUnlock()
+func (hm *HandshakeManager) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
 
 
-	for _, v := range c.indexes {
+	for _, v := range hm.indexes {
 		f(v.hostinfo)
 		f(v.hostinfo)
 	}
 	}
 }
 }
 
 
-func (c *HandshakeManager) EmitStats() {
-	c.RLock()
-	hostLen := len(c.vpnIps)
-	indexLen := len(c.indexes)
-	c.RUnlock()
+func (hm *HandshakeManager) EmitStats() {
+	hm.RLock()
+	hostLen := len(hm.vpnIps)
+	indexLen := len(hm.indexes)
+	hm.RUnlock()
 
 
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
 	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
-	c.mainHostMap.EmitStats()
+	hm.mainHostMap.EmitStats()
 }
 }
 
 
 // Utility functions below
 // Utility functions below

+ 19 - 11
handshake_manager_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
@@ -13,21 +14,20 @@ import (
 
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	localrange := netip.MustParsePrefix("10.1.1.1/24")
 	ip := netip.MustParseAddr("172.1.1.2")
 	ip := netip.MustParseAddr("172.1.1.2")
 
 
 	preferredRanges := []netip.Prefix{localrange}
 	preferredRanges := []netip.Prefix{localrange}
-	mainHM := newHostMap(l, vpncidr)
+	mainHM := newHostMap(l)
 	mainHM.preferredRanges.Store(&preferredRanges)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
 
 	lh := newTestLighthouse()
 	lh := newTestLighthouse()
 
 
 	cs := &CertState{
 	cs := &CertState{
-		RawCertificate:      []byte{},
-		PrivateKey:          []byte{},
-		Certificate:         &dummyCert{},
-		RawCertificateNoKey: []byte{},
+		defaultVersion:   cert.Version1,
+		privateKey:       []byte{},
+		v1Cert:           &dummyCert{version: cert.Version1},
+		v1HandshakeBytes: []byte{},
 	}
 	}
 
 
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	i2 := blah.StartHandshake(ip, nil)
 	i2 := blah.StartHandshake(ip, nil)
 	assert.Same(t, i, i2)
 	assert.Same(t, i, i2)
 
 
-	i.remotes = NewRemoteList(nil)
+	i.remotes = NewRemoteList([]netip.Addr{}, nil)
 
 
 	// Adding something to pending should not affect the main hostmap
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
 	assert.Len(t, mainHM.Hosts, 0)
@@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 type mockEncWriter struct {
 type mockEncWriter struct {
 }
 }
 
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
+func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
+func (mw *mockEncWriter) Handshake(_ netip.Addr) {}
+
+func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
+	return nil
+}
+
+func (mw *mockEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: cert.Version2}
+}

+ 95 - 86
hostmap.go

@@ -48,7 +48,7 @@ type Relay struct {
 	State       int
 	State       int
 	LocalIndex  uint32
 	LocalIndex  uint32
 	RemoteIndex uint32
 	RemoteIndex uint32
-	PeerIp      netip.Addr
+	PeerAddr    netip.Addr
 }
 }
 
 
 type HostMap struct {
 type HostMap struct {
@@ -58,7 +58,6 @@ type HostMap struct {
 	RemoteIndexes   map[uint32]*HostInfo
 	RemoteIndexes   map[uint32]*HostInfo
 	Hosts           map[netip.Addr]*HostInfo
 	Hosts           map[netip.Addr]*HostInfo
 	preferredRanges atomic.Pointer[[]netip.Prefix]
 	preferredRanges atomic.Pointer[[]netip.Prefix]
-	vpnCIDR         netip.Prefix
 	l               *logrus.Logger
 	l               *logrus.Logger
 }
 }
 
 
@@ -68,9 +67,9 @@ type HostMap struct {
 type RelayState struct {
 type RelayState struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
-	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
+	relays         map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer
+	relayForByAddr map[netip.Addr]*Relay   // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx  map[uint32]*Relay       // Maps a local index to some Relay info
 }
 }
 
 
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 func (rs *RelayState) DeleteRelay(ip netip.Addr) {
@@ -89,10 +88,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 	return ret
 }
 }
 
 
-func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[ip]
+	r, ok := rs.relayForByAddr[addr]
 	return r, ok
 	return r, ok
 }
 }
 
 
@@ -115,8 +114,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr {
 func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
-	for relayIp := range rs.relayForByIp {
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr))
+	for relayIp := range rs.relayForByAddr {
 		currentRelays = append(currentRelays, relayIp)
 		currentRelays = append(currentRelays, relayIp)
 	}
 	}
 	return currentRelays
 	return currentRelays
@@ -135,7 +134,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	if !ok {
 	if !ok {
 		return false
 		return false
 	}
 	}
@@ -143,7 +142,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool
 	newRelay.State = Established
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return true
 	return true
 }
 }
 
 
@@ -158,14 +157,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	newRelay.State = Established
 	newRelay.State = Established
 	newRelay.RemoteIndex = remoteIdx
 	newRelay.RemoteIndex = remoteIdx
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
 	rs.relayForByIdx[r.LocalIndex] = &newRelay
-	rs.relayForByIp[r.PeerIp] = &newRelay
+	rs.relayForByAddr[r.PeerAddr] = &newRelay
 	return &newRelay, true
 	return &newRelay, true
 }
 }
 
 
 func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	r, ok := rs.relayForByIp[vpnIp]
+	r, ok := rs.relayForByAddr[vpnIp]
 	return r, ok
 	return r, ok
 }
 }
 
 
@@ -179,7 +178,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
-	rs.relayForByIp[ip] = r
+	rs.relayForByAddr[ip] = r
 	rs.relayForByIdx[idx] = r
 	rs.relayForByIdx[idx] = r
 }
 }
 
 
@@ -190,10 +189,12 @@ type HostInfo struct {
 	ConnectionState *ConnectionState
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	remoteIndexId   uint32
 	localIndexId    uint32
 	localIndexId    uint32
-	vpnIp           netip.Addr
+	vpnAddrs        []netip.Addr
 	recvError       atomic.Uint32
 	recvError       atomic.Uint32
-	remoteCidr      *bart.Table[struct{}]
-	relayState      RelayState
+
+	// networks are both all vpn and unsafe networks assigned to this host
+	networks   *bart.Table[struct{}]
+	relayState RelayState
 
 
 	// HandshakePacket records the packets used to create this hostinfo
 	// HandshakePacket records the packets used to create this hostinfo
 	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
 	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
@@ -241,28 +242,26 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 	dropped metrics.Counter
 }
 }
 
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
-	hm := newHostMap(l, vpnCIDR)
+func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap {
+	hm := newHostMap(l)
 
 
 	hm.reload(c, true)
 	hm.reload(c, true)
 	c.RegisterReloadCallback(func(c *config.C) {
 	c.RegisterReloadCallback(func(c *config.C) {
 		hm.reload(c, false)
 		hm.reload(c, false)
 	})
 	})
 
 
-	l.WithField("network", hm.vpnCIDR.String()).
-		WithField("preferredRanges", hm.GetPreferredRanges()).
+	l.WithField("preferredRanges", hm.GetPreferredRanges()).
 		Info("Main HostMap created")
 		Info("Main HostMap created")
 
 
 	return hm
 	return hm
 }
 }
 
 
-func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
+func newHostMap(l *logrus.Logger) *HostMap {
 	return &HostMap{
 	return &HostMap{
 		Indexes:       map[uint32]*HostInfo{},
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		Hosts:         map[netip.Addr]*HostInfo{},
 		Hosts:         map[netip.Addr]*HostInfo{},
-		vpnCIDR:       vpnCIDR,
 		l:             l,
 		l:             l,
 	}
 	}
 }
 }
@@ -305,17 +304,6 @@ func (hm *HostMap) EmitStats() {
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 }
 
 
-func (hm *HostMap) RemoveRelay(localIdx uint32) {
-	hm.Lock()
-	_, ok := hm.Relays[localIdx]
-	if !ok {
-		hm.Unlock()
-		return
-	}
-	delete(hm.Relays, localIdx)
-	hm.Unlock()
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
 	// Delete the host itself, ensuring it's not modified anymore
@@ -335,7 +323,9 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 }
 }
 
 
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
 func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
-	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	//TODO: we may need to promote follow on hostinfos from these vpnAddrs as well since their oldHostinfo might not be the same as this one
+	// this really looks like an ideal spot for memory leaks
+	oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]]
 	if oldHostinfo == hostinfo {
 	if oldHostinfo == hostinfo {
 		return
 		return
 	}
 	}
@@ -348,7 +338,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
 		hostinfo.next.prev = hostinfo.prev
 		hostinfo.next.prev = hostinfo.prev
 	}
 	}
 
 
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
+	hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo
 
 
 	if oldHostinfo == nil {
 	if oldHostinfo == nil {
 		return
 		return
@@ -360,23 +350,35 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
 }
 }
 
 
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	for _, addr := range hostinfo.vpnAddrs {
+		h := hm.Hosts[addr]
+		for h != nil {
+			if h == hostinfo {
+				hm.unlockedInnerDeleteHostInfo(h, addr)
+			}
+			h = h.next
+		}
+	}
+}
+
+func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
+	primary, ok := hm.Hosts[addr]
 	if ok && primary == hostinfo {
 	if ok && primary == hostinfo {
-		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
-		delete(hm.Hosts, hostinfo.vpnIp)
+		// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, addr)
 		if len(hm.Hosts) == 0 {
 		if len(hm.Hosts) == 0 {
 			hm.Hosts = map[netip.Addr]*HostInfo{}
 			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 		}
 
 
 		if hostinfo.next != nil {
 		if hostinfo.next != nil {
-			// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
-			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
+			hm.Hosts[addr] = hostinfo.next
 			// It is primary, there is no previous hostinfo now
 			// It is primary, there is no previous hostinfo now
 			hostinfo.next.prev = nil
 			hostinfo.next.prev = nil
 		}
 		}
 
 
 	} else {
 	} else {
-		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		// Relink if we were in the middle of multiple hostinfos for this vpn addr
 		if hostinfo.prev != nil {
 		if hostinfo.prev != nil {
 			hostinfo.prev.next = hostinfo.next
 			hostinfo.prev.next = hostinfo.next
 		}
 		}
@@ -406,7 +408,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
 		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
-			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			"vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 			Debug("Hostmap hostInfo deleted")
 	}
 	}
 
 
@@ -448,11 +450,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
-	return hm.queryVpnIp(vpnIp, nil)
+func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
+	return hm.queryVpnAddr(vpnIp, nil)
 }
 }
 
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 
 
@@ -460,17 +462,21 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn
 	if !ok {
 	if !ok {
 		return nil, nil, errors.New("unable to find host")
 		return nil, nil, errors.New("unable to find host")
 	}
 	}
+
 	for h != nil {
 	for h != nil {
-		r, ok := h.relayState.QueryRelayForByIp(targetIp)
-		if ok && r.State == Established {
-			return h, r, nil
+		for _, targetIp := range targetIps {
+			r, ok := h.relayState.QueryRelayForByIp(targetIp)
+			if ok && r.State == Established {
+				return h, r, nil
+			}
 		}
 		}
 		h = h.next
 		h = h.next
 	}
 	}
+
 	return nil, nil, errors.New("unable to find host with relay")
 	return nil, nil, errors.New("unable to find host with relay")
 }
 }
 
 
-func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
 		hm.RUnlock()
@@ -491,25 +497,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
 		remoteCert := hostinfo.ConnectionState.peerCert
-		dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String())
+		dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
 	}
 	}
-
-	existing := hm.Hosts[hostinfo.vpnIp]
-	hm.Hosts[hostinfo.vpnIp] = hostinfo
-
-	if existing != nil {
-		hostinfo.next = existing
-		existing.prev = hostinfo
+	for _, addr := range hostinfo.vpnAddrs {
+		hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
 	}
 	}
 
 
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
-			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
+		hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts),
+			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}).
 			Debug("Hostmap vpnIp added")
 			Debug("Hostmap vpnIp added")
 	}
 	}
+}
+
+func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) {
+	existing := hm.Hosts[vpnAddr]
+	hm.Hosts[vpnAddr] = hostinfo
+
+	if existing != nil && existing != hostinfo {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
 
 
 	i := 1
 	i := 1
 	check := hostinfo
 	check := hostinfo
@@ -527,7 +538,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	return *hm.preferredRanges.Load()
 	return *hm.preferredRanges.Load()
 }
 }
 
 
-func (hm *HostMap) ForEachVpnIp(f controlEach) {
+func (hm *HostMap) ForEachVpnAddr(f controlEach) {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 
 
@@ -581,7 +592,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac
 		}
 		}
 
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp)
+		ifce.lightHouse.QueryServer(i.vpnAddrs[0])
 	}
 	}
 }
 }
 
 
@@ -596,7 +607,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	// We copy here because we likely got this remote from a source that reuses the object
 	if i.remote != remote {
 	if i.remote != remote {
 		i.remote = remote
 		i.remote = remote
-		i.remotes.LearnRemote(i.vpnIp, remote)
+		i.remotes.LearnRemote(i.vpnAddrs[0], remote)
 	}
 	}
 }
 }
 
 
@@ -647,21 +658,20 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 	return true
 }
 }
 
 
-func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) {
+func (i *HostInfo) buildNetworks(c cert.Certificate) {
 	if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
 	if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
 		// Simple case, no CIDRTree needed
 		// Simple case, no CIDRTree needed
 		return
 		return
 	}
 	}
 
 
-	remoteCidr := new(bart.Table[struct{}])
+	i.networks = new(bart.Table[struct{}])
 	for _, network := range c.Networks() {
 	for _, network := range c.Networks() {
-		remoteCidr.Insert(network, struct{}{})
+		i.networks.Insert(network, struct{}{})
 	}
 	}
 
 
 	for _, network := range c.UnsafeNetworks() {
 	for _, network := range c.UnsafeNetworks() {
-		remoteCidr.Insert(network, struct{}{})
+		i.networks.Insert(network, struct{}{})
 	}
 	}
-	i.remoteCidr = remoteCidr
 }
 }
 
 
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
@@ -669,7 +679,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 		return logrus.NewEntry(l)
 		return logrus.NewEntry(l)
 	}
 	}
 
 
-	li := l.WithField("vpnIp", i.vpnIp).
+	li := l.WithField("vpnAddrs", i.vpnAddrs).
 		WithField("localIndex", i.localIndexId).
 		WithField("localIndex", i.localIndexId).
 		WithField("remoteIndex", i.remoteIndexId)
 		WithField("remoteIndex", i.remoteIndexId)
 
 
@@ -684,9 +694,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 
 // Utility functions
 // Utility functions
 
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
+func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
 	//FIXME: This function is pretty garbage
-	var ips []netip.Addr
+	var finalAddrs []netip.Addr
 	ifaces, _ := net.Interfaces()
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
 		allow := allowList.AllowName(i.Name)
@@ -698,39 +708,38 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 			continue
 			continue
 		}
 		}
 		addrs, _ := i.Addrs()
 		addrs, _ := i.Addrs()
-		for _, addr := range addrs {
-			var ip net.IP
-			switch v := addr.(type) {
+		for _, rawAddr := range addrs {
+			var addr netip.Addr
+			switch v := rawAddr.(type) {
 			case *net.IPNet:
 			case *net.IPNet:
 				//continue
 				//continue
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			case *net.IPAddr:
 			case *net.IPAddr:
-				ip = v.IP
+				addr, _ = netip.AddrFromSlice(v.IP)
 			}
 			}
 
 
-			nip, ok := netip.AddrFromSlice(ip)
-			if !ok {
+			if !addr.IsValid() {
 				if l.Level >= logrus.DebugLevel {
 				if l.Level >= logrus.DebugLevel {
-					l.WithField("localIp", ip).Debug("ip was invalid for netip")
+					l.WithField("localAddr", rawAddr).Debug("addr was invalid")
 				}
 				}
 				continue
 				continue
 			}
 			}
-			nip = nip.Unmap()
+			addr = addr.Unmap()
 
 
 			//TODO: Filtering out link local for now, this is probably the most correct thing
 			//TODO: Filtering out link local for now, this is probably the most correct thing
 			//TODO: Would be nice to filter out SLAAC MAC based ips as well
 			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
-				allow := allowList.Allow(nip)
+			if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false {
+				isAllowed := allowList.Allow(addr)
 				if l.Level >= logrus.TraceLevel {
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow")
 				}
 				}
-				if !allow {
+				if !isAllowed {
 					continue
 					continue
 				}
 				}
 
 
-				ips = append(ips, nip)
+				finalAddrs = append(finalAddrs, addr)
 			}
 			}
 		}
 		}
 	}
 	}
-	return ips
+	return finalAddrs
 }
 }

+ 23 - 33
hostmap_test.go

@@ -11,17 +11,14 @@ import (
 
 
 func TestHostMap_MakePrimary(t *testing.T) {
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 
 	f := &Interface{}
 	f := &Interface{}
 
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
 
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 	hm.unlockedAddHostInfo(h1, f)
 
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 	hm.MakePrimary(h3)
 
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 	hm.MakePrimary(h4)
 
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 	hm.MakePrimary(h4)
 
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) {
 
 
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	hm := newHostMap(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-	)
+	hm := newHostMap(l)
 
 
 	f := &Interface{}
 	f := &Interface{}
 
 
-	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
-	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
-	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
-	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
-	h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
-	h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
+	h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1}
+	h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2}
+	h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3}
+	h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4}
+	h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5}
+	h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6}
 
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 	assert.Nil(t, h)
 
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 	assert.Nil(t, h1.next)
 
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 	assert.Nil(t, h3.next)
 
 
 	// Make sure we go h2 -> h4 -> h5
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 	assert.Nil(t, h5.next)
 
 
 	// Make sure we go h2 -> h4
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 	assert.Nil(t, h2.next)
 
 
 	// Make sure we only have h4
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
 	assert.Nil(t, prim.next)
@@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 	assert.Nil(t, h4.next)
 
 
 	// Make sure we have nil
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
+	prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 	assert.Nil(t, prim)
 }
 }
 
 
@@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
 
 
-	hm := NewHostMapFromConfig(
-		l,
-		netip.MustParsePrefix("10.0.0.1/24"),
-		c,
-	)
+	hm := NewHostMapFromConfig(l, c)
 
 
 	toS := func(ipn []netip.Prefix) []string {
 	toS := func(ipn []netip.Prefix) []string {
 		var s []string
 		var s []string

+ 2 - 2
hostmap_tester.go

@@ -9,8 +9,8 @@ import (
 	"net/netip"
 	"net/netip"
 )
 )
 
 
-func (i *HostInfo) GetVpnIp() netip.Addr {
-	return i.vpnIp
+func (i *HostInfo) GetVpnAddrs() []netip.Addr {
+	return i.vpnAddrs
 }
 }
 
 
 func (i *HostInfo) GetLocalIndex() uint32 {
 func (i *HostInfo) GetLocalIndex() uint32 {

+ 32 - 26
inside.go

@@ -20,14 +20,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 	}
 
 
 	// Ignore local broadcast packets
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
-		return
+	if f.dropLocalBroadcast {
+		_, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr)
+		if found {
+			return
+		}
 	}
 	}
 
 
-	if fwPacket.RemoteIP == f.myVpnNet.Addr() {
+	//TODO: seems like a huge bummer
+	_, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr)
+	if found {
 		// Immediately forward packets from self to self.
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// This should only happen on Darwin-based and FreeBSD hosts, which
-		// routes packets from the Nebula IP to the Nebula IP through the Nebula
+		// routes packets from the Nebula addr to the Nebula addr through the Nebula
 		// TUN device.
 		// TUN device.
 		if immediatelyForwardToSelf {
 		if immediatelyForwardToSelf {
 			_, err := f.readers[q].Write(packet)
 			_, err := f.readers[q].Write(packet)
@@ -36,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 			}
 			}
 		}
 		}
 		// Otherwise, drop. On linux, we should never see these packets - Linux
 		// Otherwise, drop. On linux, we should never see these packets - Linux
-		// routes packets from the nebula IP to the nebula IP through the loopback device.
+		// routes packets from the nebula addr to the nebula addr through the loopback device.
 		return
 		return
 	}
 	}
 
 
 	// Ignore multicast packets
 	// Ignore multicast packets
-	if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
+	if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() {
 		return
 		return
 	}
 	}
 
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 	})
 
 
 	if hostinfo == nil {
 	if hostinfo == nil {
 		f.rejectInside(packet, out, q)
 		f.rejectInside(packet, out, q)
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", fwPacket.RemoteIP).
+			f.l.WithField("vpnAddr", fwPacket.RemoteAddr).
 				WithField("fwPacket", fwPacket).
 				WithField("fwPacket", fwPacket).
-				Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes")
+				Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks")
 		}
 		}
 		return
 		return
 	}
 	}
@@ -117,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 }
 
 
-func (f *Interface) Handshake(vpnIp netip.Addr) {
-	f.getOrHandshake(vpnIp, nil)
+func (f *Interface) Handshake(vpnAddr netip.Addr) {
+	f.getOrHandshake(vpnAddr, nil)
 }
 }
 
 
-// getOrHandshake returns nil if the vpnIp is not routable.
+// getOrHandshake returns nil if the vpnAddr is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !f.myVpnNet.Contains(vpnIp) {
-		vpnIp = f.inside.RouteFor(vpnIp)
-		if !vpnIp.IsValid() {
+func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
+	if !found {
+		vpnAddr = f.inside.RouteFor(vpnAddr)
+		if !vpnAddr.IsValid() {
 			return nil, false
 			return nil, false
 		}
 		}
 	}
 	}
 
 
-	return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback)
+	return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
 }
 }
 
 
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@@ -156,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 }
 
 
-// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
+func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
+	hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 	})
 
 
 	if hostInfo == nil {
 	if hostInfo == nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", vpnIp).
-				Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes")
+			f.l.WithField("vpnAddr", vpnAddr).
+				Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes")
 		}
 		}
 		return
 		return
 	}
 	}
@@ -285,14 +291,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	f.connectionManager.Out(hostinfo.localIndexId)
 	f.connectionManager.Out(hostinfo.localIndexId)
 
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
-	// all our IPs and enable a faster roaming.
+	// all our addrs and enable a faster roaming.
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp)
+		f.lightHouse.QueryServer(hostinfo.vpnAddrs[0])
 		hostinfo.lastRebindCount = f.rebindCount
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter")
 		}
 		}
 	}
 	}
 
 
@@ -324,7 +330,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 	} else {
 		// Try to send via a relay
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP)
 			if err != nil {
 			if err != nil {
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")

+ 71 - 60
interface.go

@@ -2,17 +2,16 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
-	"encoding/binary"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
 	"net/netip"
 	"net/netip"
 	"os"
 	"os"
 	"runtime"
 	"runtime"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -29,7 +28,6 @@ type InterfaceConfig struct {
 	Outside                 udp.Conn
 	Outside                 udp.Conn
 	Inside                  overlay.Device
 	Inside                  overlay.Device
 	pki                     *PKI
 	pki                     *PKI
-	Cipher                  string
 	Firewall                *Firewall
 	Firewall                *Firewall
 	ServeDns                bool
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
 	HandshakeManager        *HandshakeManager
@@ -53,25 +51,27 @@ type InterfaceConfig struct {
 }
 }
 
 
 type Interface struct {
 type Interface struct {
-	hostMap            *HostMap
-	outside            udp.Conn
-	inside             overlay.Device
-	pki                *PKI
-	cipher             string
-	firewall           *Firewall
-	connectionManager  *connectionManager
-	handshakeManager   *HandshakeManager
-	serveDns           bool
-	createTime         time.Time
-	lightHouse         *LightHouse
-	myBroadcastAddr    netip.Addr
-	myVpnNet           netip.Prefix
-	dropLocalBroadcast bool
-	dropMulticast      bool
-	routines           int
-	disconnectInvalid  atomic.Bool
-	closed             atomic.Bool
-	relayManager       *relayManager
+	hostMap               *HostMap
+	outside               udp.Conn
+	inside                overlay.Device
+	pki                   *PKI
+	firewall              *Firewall
+	connectionManager     *connectionManager
+	handshakeManager      *HandshakeManager
+	serveDns              bool
+	createTime            time.Time
+	lightHouse            *LightHouse
+	myBroadcastAddrsTable *bart.Table[struct{}]
+	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
+	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
+	myVpnNetworks         []netip.Prefix        // A table of networks assigned to us via our certificate
+	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
+	dropLocalBroadcast    bool
+	dropMulticast         bool
+	routines              int
+	disconnectInvalid     atomic.Bool
+	closed                atomic.Bool
+	relayManager          *relayManager
 
 
 	tryPromoteEvery atomic.Uint32
 	tryPromoteEvery atomic.Uint32
 	reQueryEvery    atomic.Uint32
 	reQueryEvery    atomic.Uint32
@@ -103,9 +103,11 @@ type EncWriter interface {
 		out []byte,
 		out []byte,
 		nocopy bool,
 		nocopy bool,
 	)
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
+	SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp netip.Addr)
+	Handshake(vpnAddr netip.Addr)
+	GetHostInfo(vpnAddr netip.Addr) *HostInfo
+	GetCertState() *CertState
 }
 }
 
 
 type sendRecvErrorConfig uint8
 type sendRecvErrorConfig uint8
@@ -116,10 +118,10 @@ const (
 	sendRecvErrorPrivate
 	sendRecvErrorPrivate
 )
 )
 
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool {
 	switch s {
 	switch s {
 	case sendRecvErrorPrivate:
 	case sendRecvErrorPrivate:
-		return ip.Addr().IsPrivate()
+		return endpoint.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 	case sendRecvErrorAlways:
 		return true
 		return true
 	case sendRecvErrorNever:
 	case sendRecvErrorNever:
@@ -156,27 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		return nil, errors.New("no firewall rules")
 		return nil, errors.New("no firewall rules")
 	}
 	}
 
 
-	certificate := c.pki.GetCertState().Certificate
-
+	cs := c.pki.getCertState()
 	ifce := &Interface{
 	ifce := &Interface{
-		pki:                c.pki,
-		hostMap:            c.HostMap,
-		outside:            c.Outside,
-		inside:             c.Inside,
-		cipher:             c.Cipher,
-		firewall:           c.Firewall,
-		serveDns:           c.ServeDns,
-		handshakeManager:   c.HandshakeManager,
-		createTime:         time.Now(),
-		lightHouse:         c.lightHouse,
-		dropLocalBroadcast: c.DropLocalBroadcast,
-		dropMulticast:      c.DropMulticast,
-		routines:           c.routines,
-		version:            c.version,
-		writers:            make([]udp.Conn, c.routines),
-		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnNet:           certificate.Networks()[0],
-		relayManager:       c.relayManager,
+		pki:                   c.pki,
+		hostMap:               c.HostMap,
+		outside:               c.Outside,
+		inside:                c.Inside,
+		firewall:              c.Firewall,
+		serveDns:              c.ServeDns,
+		handshakeManager:      c.HandshakeManager,
+		createTime:            time.Now(),
+		lightHouse:            c.lightHouse,
+		dropLocalBroadcast:    c.DropLocalBroadcast,
+		dropMulticast:         c.DropMulticast,
+		routines:              c.routines,
+		version:               c.version,
+		writers:               make([]udp.Conn, c.routines),
+		readers:               make([]io.ReadWriteCloser, c.routines),
+		myVpnNetworks:         cs.myVpnNetworks,
+		myVpnNetworksTable:    cs.myVpnNetworksTable,
+		myVpnAddrs:            cs.myVpnAddrs,
+		myVpnAddrsTable:       cs.myVpnAddrsTable,
+		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
+		relayManager:          c.relayManager,
 
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
 
@@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 		l: c.l,
 	}
 	}
 
 
-	if ifce.myVpnNet.Addr().Is4() {
-		maskedAddr := certificate.Networks()[0].Masked()
-		addr := maskedAddr.Addr().As4()
-		mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen())
-		binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
-		ifce.myBroadcastAddr = netip.AddrFrom4(addr)
-	}
-
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
@@ -218,7 +214,7 @@ func (f *Interface) activate() {
 		f.l.WithError(err).Error("Failed to get udp listen address")
 		f.l.WithError(err).Error("Failed to get udp listen address")
 	}
 	}
 
 
-	f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()).
+	f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("build", f.version).WithField("udpAddr", addr).
 		WithField("boringcrypto", boringEnabled()).
 		WithField("boringcrypto", boringEnabled()).
 		Info("Nebula interface is active")
 		Info("Nebula interface is active")
@@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) {
 	runtime.LockOSThread()
 	runtime.LockOSThread()
 
 
 	var li udp.Conn
 	var li udp.Conn
-	// TODO clean this up with a coherent interface for each outside connection
 	if i > 0 {
 	if i > 0 {
 		li = f.writers[i]
 		li = f.writers[i]
 	} else {
 	} else {
 		li = f.outside
 		li = f.outside
 	}
 	}
 
 
+	ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
 	lhh := f.lightHouse.NewRequestHandler()
 	lhh := f.lightHouse.NewRequestHandler()
-	conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout)
-	li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i)
+	plaintext := make([]byte, udp.MTU)
+	h := &header.H{}
+	fwPacket := &firewall.Packet{}
+	nb := make([]byte, 12, 12)
+
+	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
+		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	})
 }
 }
 
 
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
@@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 		return
 	}
 	}
 
 
-	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
 		return
@@ -417,11 +419,20 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
 			udpStats()
-			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second))
+			certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second))
+			//TODO: we should also report the default certificate version
 		}
 		}
 	}
 	}
 }
 }
 
 
+func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return f.hostMap.QueryVpnAddr(vpnIp)
+}
+
+func (f *Interface) GetCertState() *CertState {
+	return f.pki.getCertState()
+}
+
 func (f *Interface) Close() error {
 func (f *Interface) Close() error {
 	f.closed.Store(true)
 	f.closed.Store(true)
 
 

File diff suppressed because it is too large
+ 490 - 298
lighthouse.go


+ 93 - 66
lighthouse_test.go

@@ -7,6 +7,8 @@ import (
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 
 
+	"github.com/gaissmai/bart"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
@@ -19,57 +21,48 @@ import (
 func TestOldIPv4Only(t *testing.T) {
 func TestOldIPv4Only(t *testing.T) {
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	// This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
 	b := []byte{8, 129, 130, 132, 80, 16, 10}
-	var m Ip4AndPort
+	var m V4AddrPort
 	err := m.Unmarshal(b)
 	err := m.Unmarshal(b)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	ip := netip.MustParseAddr("10.1.1.1")
 	ip := netip.MustParseAddr("10.1.1.1")
 	bp := ip.As4()
 	bp := ip.As4()
-	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
-}
-
-func TestNewLhQuery(t *testing.T) {
-	myIp, err := netip.ParseAddr("192.1.1.1")
-	assert.NoError(t, err)
-
-	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIp)
-
-	// The result should be a nebulameta protobuf
-	assert.IsType(t, &NebulaMeta{}, a)
-
-	// It should also Marshal fine
-	b, err := a.Marshal()
-	assert.Nil(t, err)
-
-	// and then Unmarshal fine
-	n := &NebulaMeta{}
-	err = n.Unmarshal(b)
-	assert.Nil(t, err)
-
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
 }
 }
 
 
 func Test_lhStaticMapping(t *testing.T) {
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	_, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	lh2 := "10.128.0.3"
 	lh2 := "10.128.0.3"
 	c = config.NewC(l)
 	c = config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
-	_, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 	assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
 }
 
 
 func TestReloadLighthouseInterval(t *testing.T) {
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
@@ -79,7 +72,7 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	}
 	}
 
 
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
 	c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	lh.ifce = &mockEncWriter{}
 	lh.ifce = &mockEncWriter{}
 
 
@@ -99,9 +92,15 @@ func TestReloadLighthouseInterval(t *testing.T) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
 	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	if !assert.NoError(b, err) {
 	if !assert.NoError(b, err) {
 		b.Fatal()
 		b.Fatal()
 	}
 	}
@@ -110,46 +109,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
 	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
 
 
 	vpnIp3 := netip.MustParseAddr("0.0.0.3")
 	vpnIp3 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp3] = NewRemoteList(nil)
+	lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil)
 	lh.addrMap[vpnIp3].unlockedSetV4(
 	lh.addrMap[vpnIp3].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
-			NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()),
+			netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()),
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
 	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
 	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
 	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
 	vpnIp2 := netip.MustParseAddr("0.0.0.3")
 	vpnIp2 := netip.MustParseAddr("0.0.0.3")
-	lh.addrMap[vpnIp2] = NewRemoteList(nil)
+	lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil)
 	lh.addrMap[vpnIp2].unlockedSetV4(
 	lh.addrMap[vpnIp2].unlockedSetV4(
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
 		vpnIp3,
-		[]*Ip4AndPort{
-			NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
-			NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
+		[]*V4AddrPort{
+			netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()),
+			netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()),
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 
 
+	hi := []netip.Addr{vpnIp2}
 	b.Run("notfound", func(b *testing.B) {
 	b.Run("notfound", func(b *testing.B) {
 		lhh := lh.NewRequestHandler()
 		lhh := lh.NewRequestHandler()
 		req := &NebulaMeta{
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
 			Details: &NebulaMetaDetails{
-				VpnIp:       4,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  4,
+				V4AddrPorts: nil,
 			},
 			},
 		}
 		}
 		p, err := req.Marshal()
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		assert.NoError(b, err)
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 		}
 	})
 	})
 	b.Run("found", func(b *testing.B) {
 	b.Run("found", func(b *testing.B) {
@@ -157,15 +157,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		req := &NebulaMeta{
 		req := &NebulaMeta{
 			Type: NebulaMeta_HostQuery,
 			Type: NebulaMeta_HostQuery,
 			Details: &NebulaMetaDetails{
 			Details: &NebulaMetaDetails{
-				VpnIp:       3,
-				Ip4AndPorts: nil,
+				OldVpnAddr:  3,
+				V4AddrPorts: nil,
 			},
 			},
 		}
 		}
 		p, err := req.Marshal()
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		assert.NoError(b, err)
 
 
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
+			lhh.HandleRequest(rAddr, hi, p, mw)
 		}
 		}
 	})
 	})
 }
 }
@@ -197,40 +197,49 @@ func TestLighthouse_Memory(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
+	lh.ifce = &mockEncWriter{}
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 	lhh := lh.NewRequestHandler()
 
 
 	// Test that my first update responds with just that
 	// Test that my first update responds with just that
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2)
 
 
 	// Ensure we don't accumulate addresses
 	// Ensure we don't accumulate addresses
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3)
 
 
 	// Grow it back to 2
 	// Grow it back to 2
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 
 	// Update a different host and ask about it
 	// Update a different host and ask about it
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 
 	// Have both hosts ask about the other
 	// Have both hosts ask about the other
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 
 	// Make sure we didn't get changed
 	// Make sure we didn't get changed
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4)
 
 
 	// Ensure proper ordering and limiting
 	// Ensure proper ordering and limiting
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
 	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
@@ -255,7 +264,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 	assertIp4InArray(
 		t,
 		t,
-		r.msg.Details.Ip4AndPorts,
+		r.msg.Details.V4AddrPorts,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 	)
 
 
@@ -265,7 +274,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	good := netip.MustParseAddrPort("1.128.0.99:4242")
 	good := netip.MustParseAddrPort("1.128.0.99:4242")
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
-	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
+	assertIp4InArray(t, r.msg.Details.V4AddrPorts, good)
 }
 }
 
 
 func TestLighthouse_reload(t *testing.T) {
 func TestLighthouse_reload(t *testing.T) {
@@ -273,7 +282,16 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
+
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/24")
+	nt := new(bart.Table[struct{}])
+	nt.Insert(myVpnNet, struct{}{})
+	cs := &CertState{
+		myVpnNetworks:      []netip.Prefix{myVpnNet},
+		myVpnNetworksTable: nt,
+	}
+
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	nc := map[interface{}]interface{}{
 	nc := map[interface{}]interface{}{
@@ -295,7 +313,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
 	req := &NebulaMeta{
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp: binary.BigEndian.Uint32(bip[:]),
+			OldVpnAddr: binary.BigEndian.Uint32(bip[:]),
 		},
 		},
 	}
 	}
 
 
@@ -308,7 +326,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
 	w := &testEncWriter{
 	w := &testEncWriter{
 		metaFilter: &filter,
 		metaFilter: &filter,
 	}
 	}
-	lhh.HandleRequest(fromAddr, myVpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
 	return w.lastReply
 	return w.lastReply
 }
 }
 
 
@@ -318,13 +336,13 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
 	req := &NebulaMeta{
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp:       binary.BigEndian.Uint32(bip[:]),
-			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
+			OldVpnAddr:  binary.BigEndian.Uint32(bip[:]),
+			V4AddrPorts: make([]*V4AddrPort, len(addrs)),
 		},
 		},
 	}
 	}
 
 
 	for k, v := range addrs {
 	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
+		req.Details.V4AddrPorts[k] = netAddrToProtoV4AddrPort(v.Addr(), v.Port())
 	}
 	}
 
 
 	b, err := req.Marshal()
 	b, err := req.Marshal()
@@ -333,7 +351,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
 	}
 	}
 
 
 	w := &testEncWriter{}
 	w := &testEncWriter{}
-	lhh.HandleRequest(fromAddr, vpnIp, b, w)
+	lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
 }
 }
 
 
 //TODO: this is a RemoteList test
 //TODO: this is a RemoteList test
@@ -410,8 +428,9 @@ type testLhReply struct {
 }
 }
 
 
 type testEncWriter struct {
 type testEncWriter struct {
-	lastReply  testLhReply
-	metaFilter *NebulaMeta_MessageType
+	lastReply       testLhReply
+	metaFilter      *NebulaMeta_MessageType
+	protocolVersion cert.Version
 }
 }
 
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
@@ -426,7 +445,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 		tw.lastReply = testLhReply{
 		tw.lastReply = testLhReply{
 			nebType:    t,
 			nebType:    t,
 			nebSubType: st,
 			nebSubType: st,
-			vpnIp:      hostinfo.vpnIp,
+			vpnIp:      hostinfo.vpnAddrs[0],
 			msg:        msg,
 			msg:        msg,
 		}
 		}
 	}
 	}
@@ -436,7 +455,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 	}
 }
 }
 
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -453,15 +472,23 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 	}
 	}
 }
 }
 
 
+func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
+	return nil
+}
+
+func (tw *testEncWriter) GetCertState() *CertState {
+	return &CertState{defaultVersion: tw.protocolVersion}
+}
+
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
+func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 	if !assert.Len(t, have, len(want)) {
 		return
 		return
 	}
 	}
 
 
 	for k, w := range want {
 	for k, w := range want {
 		//TODO: IPV6-WORK
 		//TODO: IPV6-WORK
-		h := AddrPortFromIp4AndPort(have[k])
+		h := protoV4AddrPortToNetAddrPort(have[k])
 		if !(h == w) {
 		if !(h == w) {
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 		}

+ 5 - 19
main.go

@@ -2,7 +2,6 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
-	"encoding/binary"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
@@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
 	}
 
 
-	certificate := pki.GetCertState().Certificate
-	fw, err := NewFirewallFromConfig(l, certificate, c)
+	fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
 
-	tunCidr := certificate.Networks()[0]
-
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
 		return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
@@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 			deviceFactory = overlay.NewDeviceFromConfig
 			deviceFactory = overlay.NewDeviceFromConfig
 		}
 		}
 
 
-		tun, err = deviceFactory(c, l, tunCidr, routines)
+		tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
 		if err != nil {
 		if err != nil {
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
 		}
@@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 		}
 	}
 	}
 
 
-	hostMap := NewHostMapFromConfig(l, tunCidr, c)
+	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
-	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
+	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
 	}
@@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		Inside:                  tun,
 		Inside:                  tun,
 		Outside:                 udpConns[0],
 		Outside:                 udpConns[0],
 		pki:                     pki,
 		pki:                     pki,
-		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		Firewall:                fw,
 		ServeDns:                serveDns,
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
 		HandshakeManager:        handshakeManager,
@@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		l:                     l,
 		l:                     l,
 	}
 	}
 
 
-	switch ifConfig.Cipher {
-	case "aes":
-		noiseEndianness = binary.BigEndian
-	case "chachapoly":
-		noiseEndianness = binary.LittleEndian
-	default:
-		return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher)
-	}
-
 	var ifce *Interface
 	var ifce *Interface
 	if !configTest {
 	if !configTest {
 		ifce, err = NewInterface(ctx, ifConfig)
 		ifce, err = NewInterface(ctx, ifConfig)
@@ -303,7 +289,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	var dnsStart func()
 	var dnsStart func()
 	if lightHouse.amLighthouse && serveDns {
 	if lightHouse.amLighthouse && serveDns {
 		l.Debugln("Starting dns server")
 		l.Debugln("Starting dns server")
-		dnsStart = dnsMain(l, hostMap, c)
+		dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
 	}
 	}
 
 
 	return &Control{
 	return &Control{

File diff suppressed because it is too large
+ 467 - 171
nebula.pb.go


+ 23 - 9
nebula.proto

@@ -23,19 +23,28 @@ message NebulaMeta {
 }
 }
 
 
 message NebulaMetaDetails {
 message NebulaMetaDetails {
-  uint32 VpnIp = 1;
-  repeated Ip4AndPort Ip4AndPorts = 2;
-  repeated Ip6AndPort Ip6AndPorts = 4;
-  repeated uint32 RelayVpnIp = 5;
+  uint32 OldVpnAddr = 1 [deprecated = true];
+  Addr VpnAddr = 6;
+
+  repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true];
+  repeated Addr RelayVpnAddrs = 7;
+
+  repeated V4AddrPort V4AddrPorts = 2;
+  repeated V6AddrPort V6AddrPorts = 4;
   uint32 counter = 3;
   uint32 counter = 3;
 }
 }
 
 
-message Ip4AndPort {
-  uint32 Ip = 1;
+message Addr {
+  uint64 Hi = 1;
+  uint64 Lo = 2;
+}
+
+message V4AddrPort {
+  uint32 Addr = 1;
   uint32 Port = 2;
   uint32 Port = 2;
 }
 }
 
 
-message Ip6AndPort {
+message V6AddrPort {
   uint64 Hi = 1;
   uint64 Hi = 1;
   uint64 Lo = 2;
   uint64 Lo = 2;
   uint32 Port = 3;
   uint32 Port = 3;
@@ -62,6 +71,7 @@ message NebulaHandshakeDetails {
   uint32 ResponderIndex = 3;
   uint32 ResponderIndex = 3;
   uint64 Cookie = 4;
   uint64 Cookie = 4;
   uint64 Time = 5;
   uint64 Time = 5;
+  uint32 CertVersion = 8;
   // reserved for WIP multiport
   // reserved for WIP multiport
   reserved 6, 7;
   reserved 6, 7;
 }
 }
@@ -76,6 +86,10 @@ message NebulaControl {
 
 
   uint32 InitiatorRelayIndex = 2;
   uint32 InitiatorRelayIndex = 2;
   uint32 ResponderRelayIndex = 3;
   uint32 ResponderRelayIndex = 3;
-  uint32 RelayToIp = 4;
-  uint32 RelayFromIp = 5;
+
+  uint32 OldRelayToAddr = 4 [deprecated = true];
+  uint32 OldRelayFromAddr = 5 [deprecated = true];
+
+  Addr RelayToAddr = 6;
+  Addr RelayFromAddr = 7;
 }
 }

+ 135 - 82
outside.go

@@ -7,12 +7,12 @@ import (
 	"net/netip"
 	"net/netip"
 	"time"
 	"time"
 
 
-	"github.com/flynn/noise"
+	"github.com/google/gopacket/layers"
+	"golang.org/x/net/ipv6"
+
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
 
 
@@ -20,24 +20,7 @@ const (
 	minFwPacketLen = 4
 	minFwPacketLen = 4
 )
 )
 
 
-// TODO: IPV6-WORK this can likely be removed now
-func readOutsidePackets(f *Interface) udp.EncReader {
-	return func(
-		addr netip.AddrPort,
-		out []byte,
-		packet []byte,
-		header *header.H,
-		fwPacket *firewall.Packet,
-		lhh udp.LightHouseHandlerFunc,
-		nb []byte,
-		q int,
-		localCache firewall.ConntrackCache,
-	) {
-		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
-	}
-}
-
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	err := h.Parse(packet)
 	if err != nil {
 	if err != nil {
 		// TODO: best if we return this and let caller log
 		// TODO: best if we return this and let caller log
@@ -51,7 +34,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	if ip.IsValid() {
 	if ip.IsValid() {
-		if f.myVpnNet.Contains(ip.Addr()) {
+		_, found := f.myVpnNetworksTable.Lookup(ip.Addr())
+		if found {
 			if f.l.Level >= logrus.DebugLevel {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
 			}
@@ -108,7 +92,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			if !ok {
 			if !ok {
 				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
 				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
 				// its internal mapping. This should never happen.
 				// its internal mapping. This should never happen.
-				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
+				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
 				return
 				return
 			}
 			}
 
 
@@ -120,9 +104,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				return
 				return
 			case ForwardingType:
 			case ForwardingType:
 				// Find the target HostInfo relay object
 				// Find the target HostInfo relay object
-				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
+				targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
 				if err != nil {
 				if err != nil {
-					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
+					hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip")
 					return
 					return
 				}
 				}
 
 
@@ -138,7 +122,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
 					}
 					}
 				} else {
 				} else {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
+					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
 					return
 					return
 				}
 				}
 			}
 			}
@@ -161,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			return
 			return
 		}
 		}
 
 
-		lhf(ip, hostinfo.vpnIp, d)
+		lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
 
 
 		// Fallthrough to the bottom to record incoming traffic
 		// Fallthrough to the bottom to record incoming traffic
 
 
@@ -228,14 +212,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 				Error("Failed to decrypt Control packet")
 				Error("Failed to decrypt Control packet")
 			return
 			return
 		}
 		}
-		m := &NebulaControl{}
-		err = m.Unmarshal(d)
-		if err != nil {
-			hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
-			break
-		}
 
 
-		f.relayManager.HandleControlMsg(hostinfo, m, f)
+		f.relayManager.HandleControlMsg(hostinfo, d, f)
 
 
 	default:
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
@@ -252,8 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	if final {
 	if final {
-		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
-		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
+		// We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage
+		f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs)
 	}
 	}
 }
 }
 
 
@@ -262,25 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 }
 
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
-	if ip.IsValid() && hostinfo.remote != ip {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) {
+	if vpnAddr.IsValid() && hostinfo.remote != vpnAddr {
+		//TODO: this is weird now that we can have multiple vpn addrs
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 			return
 		}
 		}
-		if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+		if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			}
 			return
 			return
 		}
 		}
 
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
 			Info("Host roamed to new udp ip/port.")
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(ip)
+		hostinfo.SetRemote(vpnAddr)
 	}
 	}
 
 
 }
 }
@@ -302,14 +281,114 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
 
 
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
-	// Do we at least have an ipv4 header worth of data?
-	if len(data) < ipv4.HeaderLen {
-		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
+	if len(data) < 1 {
+		return errors.New("packet too short")
+	}
+
+	version := int((data[0] >> 4) & 0x0f)
+	switch version {
+	case ipv4.Version:
+		return parseV4(data, incoming, fp)
+	case ipv6.Version:
+		return parseV6(data, incoming, fp)
 	}
 	}
+	return fmt.Errorf("packet is an unknown ip version: %v", version)
+}
 
 
-	// Is it an ipv4 packet?
-	if int((data[0]>>4)&0x0f) != 4 {
-		return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
+func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
+	dataLen := len(data)
+	if dataLen < ipv6.HeaderLen {
+		return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
+	}
+
+	if incoming {
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
+	} else {
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
+	}
+
+	//TODO: whats a reasonable number of extension headers to attempt to parse?
+	//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
+	protoAt := 6
+	offset := 40
+	for i := 0; i < 24; i++ {
+		if dataLen < offset {
+			break
+		}
+
+		proto := layers.IPProtocol(data[protoAt])
+		//fmt.Println(proto, protoAt)
+		switch proto {
+		case layers.IPProtocolICMPv6:
+			//TODO: we need a new protocol in config language "icmpv6"
+			fp.Protocol = uint8(proto)
+			fp.RemotePort = 0
+			fp.LocalPort = 0
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolTCP:
+			if dataLen < offset+4 {
+				return fmt.Errorf("ipv6 packet was too small")
+			}
+			fp.Protocol = uint8(proto)
+			if incoming {
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			} else {
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			}
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolUDP:
+			if dataLen < offset+4 {
+				return fmt.Errorf("ipv6 packet was too small")
+			}
+			fp.Protocol = uint8(proto)
+			if incoming {
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			} else {
+				fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
+				fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
+			}
+			fp.Fragment = false
+			return nil
+
+		case layers.IPProtocolIPv6Fragment:
+			//TODO: can we determine the protocol?
+			fp.RemotePort = 0
+			fp.LocalPort = 0
+			fp.Fragment = true
+			return nil
+
+		default:
+			if dataLen < offset+1 {
+				break
+			}
+
+			next := int(data[offset+1]) * 8
+			if next == 0 {
+				// each extension is at least 8 bytes
+				next = 8
+			}
+
+			protoAt = offset
+			offset = offset + next
+		}
+	}
+
+	return fmt.Errorf("could not find payload in ipv6 packet")
+}
+
+func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
+	// Do we at least have an ipv4 header worth of data?
+	if len(data) < ipv4.HeaderLen {
+		return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
 	}
 	}
 
 
 	// Adjust our start position based on the advertised ip header length
 	// Adjust our start position based on the advertised ip header length
@@ -317,7 +396,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 
 
 	// Well formed ip header length?
 	// Well formed ip header length?
 	if ihl < ipv4.HeaderLen {
 	if ihl < ipv4.HeaderLen {
-		return fmt.Errorf("packet had an invalid header length: %v", ihl)
+		return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
 	}
 	}
 
 
 	// Check if this is the second or further fragment of a fragmented packet.
 	// Check if this is the second or further fragment of a fragmented packet.
@@ -333,14 +412,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 		minLen += minFwPacketLen
 		minLen += minFwPacketLen
 	}
 	}
 	if len(data) < minLen {
 	if len(data) < minLen {
-		return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
+		return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
 	}
 	}
 
 
 	// Firewall packets are locally oriented
 	// Firewall packets are locally oriented
 	if incoming {
 	if incoming {
-		//TODO: IPV6-WORK
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
@@ -349,9 +427,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 		}
 	} else {
 	} else {
-		//TODO: IPV6-WORK
-		fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
-		fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
+		fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
+		fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
@@ -492,27 +569,3 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
 	f.outside.WriteTo(msg, endpoint)
 	f.outside.WriteTo(msg, endpoint)
 }
 }
 */
 */
-
-func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) {
-	pk := h.PeerStatic()
-
-	if pk == nil {
-		return nil, errors.New("no peer static key was present")
-	}
-
-	if rawCertBytes == nil {
-		return nil, errors.New("provided payload was empty")
-	}
-
-	c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk)
-	if err != nil {
-		return nil, fmt.Errorf("error unmarshaling cert: %w", err)
-	}
-
-	cc, err := caPool.VerifyCertificate(time.Now(), c)
-	if err != nil {
-		return nil, fmt.Errorf("certificate validation failed: %w", err)
-	}
-
-	return cc, nil
-}

+ 71 - 10
outside_test.go

@@ -5,6 +5,9 @@ import (
 	"net/netip"
 	"net/netip"
 	"testing"
 	"testing"
 
 
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers"
+
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
@@ -13,9 +16,15 @@ import (
 func Test_newPacket(t *testing.T) {
 func Test_newPacket(t *testing.T) {
 	p := &firewall.Packet{}
 	p := &firewall.Packet{}
 
 
-	// length fail
-	err := newPacket([]byte{0, 1}, true, p)
-	assert.EqualError(t, err, "packet is less than 20 bytes")
+	// length fails
+	err := newPacket([]byte{}, true, p)
+	assert.EqualError(t, err, "packet too short")
+
+	err = newPacket([]byte{0x40}, true, p)
+	assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")
+
+	err = newPacket([]byte{0x60}, true, p)
+	assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")
 
 
 	// length fail with ip options
 	// length fail with ip options
 	h := ipv4.Header{
 	h := ipv4.Header{
@@ -29,15 +38,15 @@ func Test_newPacket(t *testing.T) {
 	b, _ := h.Marshal()
 	b, _ := h.Marshal()
 	err = newPacket(b, true, p)
 	err = newPacket(b, true, p)
 
 
-	assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
+	assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")
 
 
 	// not an ipv4 packet
 	// not an ipv4 packet
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
 	err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet is not ipv4, type: 0")
+	assert.EqualError(t, err, "packet is an unknown ip version: 0")
 
 
 	// invalid ihl
 	// invalid ihl
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
 	err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
-	assert.EqualError(t, err, "packet had an invalid header length: 8")
+	assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")
 
 
 	// account for variable ip header length - incoming
 	// account for variable ip header length - incoming
 	h = ipv4.Header{
 	h = ipv4.Header{
@@ -55,8 +64,8 @@ func Test_newPacket(t *testing.T) {
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
 	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
+	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2"))
+	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1"))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.LocalPort, uint16(4))
 	assert.Equal(t, p.LocalPort, uint16(4))
 
 
@@ -76,8 +85,60 @@ func Test_newPacket(t *testing.T) {
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(2))
 	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
-	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
+	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1"))
+	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2"))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.LocalPort, uint16(5))
 	assert.Equal(t, p.LocalPort, uint16(5))
 }
 }
+
+func Test_newPacket_v6(t *testing.T) {
+	p := &firewall.Packet{}
+
+	ip := layers.IPv6{
+		Version:    6,
+		NextHeader: firewall.ProtoUDP,
+		HopLimit:   128,
+		SrcIP:      net.IPv6linklocalallrouters,
+		DstIP:      net.IPv6linklocalallnodes,
+	}
+
+	udp := layers.UDP{
+		SrcPort: layers.UDPPort(36123),
+		DstPort: layers.UDPPort(22),
+	}
+	err := udp.SetNetworkLayerForChecksum(&ip)
+	if err != nil {
+		panic(err)
+	}
+
+	buffer := gopacket.NewSerializeBuffer()
+	opt := gopacket.SerializeOptions{
+		ComputeChecksums: true,
+		FixLengths:       true,
+	}
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
+	if err != nil {
+		panic(err)
+	}
+	b := buffer.Bytes()
+
+	//test incoming
+	err = newPacket(b, true, p)
+
+	assert.Nil(t, err)
+	assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
+	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2"))
+	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1"))
+	assert.Equal(t, p.RemotePort, uint16(36123))
+	assert.Equal(t, p.LocalPort, uint16(22))
+
+	//test outgoing
+	err = newPacket(b, false, p)
+
+	assert.Nil(t, err)
+	assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
+	assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2"))
+	assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1"))
+	assert.Equal(t, p.LocalPort, uint16(36123))
+	assert.Equal(t, p.RemotePort, uint16(22))
+}

+ 1 - 1
overlay/device.go

@@ -8,7 +8,7 @@ import (
 type Device interface {
 type Device interface {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	Activate() error
 	Activate() error
-	Cidr() netip.Prefix
+	Networks() []netip.Prefix
 	Name() string
 	Name() string
 	RouteFor(netip.Addr) netip.Addr
 	RouteFor(netip.Addr) netip.Addr
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 	NewMultiQueueReader() (io.ReadWriteCloser, error)

+ 22 - 12
overlay/route.go

@@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table
 	return routeTree, nil
 	return routeTree, nil
 }
 }
 
 
-func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.routes")
 	r := c.Get("tun.routes")
@@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 		}
 		}
 
 
-		if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
+		found := false
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() {
+				found = true
+				break
+			}
+		}
+
+		if !found {
 			return nil, fmt.Errorf(
 			return nil, fmt.Errorf(
-				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
+				"entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v",
 				i+1,
 				i+1,
 				r.Cidr.String(),
 				r.Cidr.String(),
-				network.String(),
+				networks,
 			)
 			)
 		}
 		}
 
 
@@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	return routes, nil
 	return routes, nil
 }
 }
 
 
-func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.unsafe_routes")
 	r := c.Get("tun.unsafe_routes")
@@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 		}
 		}
 
 
-		if network.Contains(r.Cidr.Addr()) {
-			return nil, fmt.Errorf(
-				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
-				i+1,
-				r.Cidr.String(),
-				network.String(),
-			)
+		for _, network := range networks {
+			if network.Contains(r.Cidr.Addr()) {
+				return nil, fmt.Errorf(
+					"entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v",
+					i+1,
+					r.Cidr.String(),
+					network.String(),
+				)
+			}
 		}
 		}
 
 
 		routes[i] = r
 		routes[i] = r

+ 39 - 33
overlay/route_test.go

@@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	// test no routes config
 	// test no routes config
-	routes, err := parseRoutes(c, n)
+	routes, err := parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// not an array
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.routes is not an array")
 	assert.EqualError(t, err, "tun.routes is not an array")
 
 
 	// no routes
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// weird route
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 	assert.EqualError(t, err, "entry 1 in tun.routes is invalid")
 
 
 	// no mtu
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
 
 
 	// bad mtu
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 
 	// low mtu
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 	assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 	assert.EqualError(t, err, "entry 1.route in tun.routes is not present")
 
 
 	// unparsable route
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
 
 
 	// above network range
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
+
+	// Not in multiple ranges
+	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
+	routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
+	assert.Nil(t, routes)
+	assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
 
 
 	// happy case
 	// happy case
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 		map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
 	}}
 	}}
-	routes, err = parseRoutes(c, n)
+	routes, err = parseRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 2)
 	assert.Len(t, routes, 2)
 
 
@@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	// test no routes config
 	// test no routes config
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// not an array
 	// not an array
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 	assert.EqualError(t, err, "tun.unsafe_routes is not an array")
 
 
 	// no routes
 	// no routes
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 0)
 	assert.Len(t, routes, 0)
 
 
 	// weird route
 	// weird route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 	assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
 
 
 	// no via
 	// no via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
 
 
@@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		127, false, nil, 1.0, []string{"1", "2"},
 		127, false, nil, 1.0, []string{"1", "2"},
 	} {
 	} {
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
 		c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
-		routes, err = parseUnsafeRoutes(c, n)
+		routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 		assert.Nil(t, routes)
 		assert.Nil(t, routes)
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 		assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
 	}
 	}
 
 
 	// unparsable via
 	// unparsable via
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
 
 
 	// unparsable route
 	// unparsable route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 
 	// within network range
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24")
+	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	// above network range
 	// above network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 
 
 	// no mtu
 	// no mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Len(t, routes, 1)
 	assert.Len(t, routes, 1)
 	assert.Equal(t, 0, routes[0].MTU)
 	assert.Equal(t, 0, routes[0].MTU)
 
 
 	// bad mtu
 	// bad mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
 
 
 	// low mtu
 	// low mtu
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 	assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
 
 
 	// bad install
 	// bad install
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
 	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 	assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
 
 
@@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 		map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
 	}}
 	}}
-	routes, err = parseUnsafeRoutes(c, n)
+	routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Len(t, routes, 4)
 	assert.Len(t, routes, 4)
 
 
@@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) {
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 		map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
 	}}
 	}}
-	routes, err := parseUnsafeRoutes(c, n)
+	routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	assert.Len(t, routes, 2)
 	assert.Len(t, routes, 2)
 	routeTree, err := makeRouteTree(l, routes, true)
 	routeTree, err := makeRouteTree(l, routes, true)

+ 9 - 9
overlay/tun.go

@@ -11,36 +11,36 @@ import (
 const DefaultMTU = 1300
 const DefaultMTU = 1300
 
 
 // TODO: We may be able to remove routines
 // TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
 
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
 	switch {
 	switch {
 	case c.GetBool("tun.disabled", false):
 	case c.GetBool("tun.disabled", false):
-		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
+		tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		return tun, nil
 		return tun, nil
 
 
 	default:
 	default:
-		return newTun(c, l, tunCidr, routines > 1)
+		return newTun(c, l, vpnNetworks, routines > 1)
 	}
 	}
 }
 }
 
 
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
-		return newTunFromFd(c, l, *fd, tunCidr)
+	return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+		return newTunFromFd(c, l, *fd, vpnNetworks)
 	}
 	}
 }
 }
 
 
-func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 		return false, nil, nil
 		return false, nil, nil
 	}
 	}
 
 
-	routes, err := parseRoutes(c, cidr)
+	routes, err := parseRoutes(c, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 		return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err)
 	}
 	}
 
 
-	unsafeRoutes, err := parseUnsafeRoutes(c, cidr)
+	unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 		return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err)
 	}
 	}

+ 11 - 11
overlay/tun_android.go

@@ -18,14 +18,14 @@ import (
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	fd        int
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	fd          int
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              deviceFd,
 		fd:              deviceFd,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		l:               l,
 		l:               l,
 	}
 	}
 
 
@@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 }
 
 
@@ -66,7 +66,7 @@ func (t tun) Activate() error {
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 207 - 214
overlay/tun_darwin.go

@@ -24,56 +24,62 @@ import (
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	Device     string
-	cidr       netip.Prefix
-	DefaultMTU int
-	Routes     atomic.Pointer[[]Route]
-	routeTree  atomic.Pointer[bart.Table[netip.Addr]]
-	linkAddr   *netroute.LinkAddr
-	l          *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	DefaultMTU  int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	linkAddr    *netroute.LinkAddr
+	l           *logrus.Logger
 
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
 	out []byte
 	out []byte
 }
 }
 
 
-type sockaddrCtl struct {
-	scLen      uint8
-	scFamily   uint8
-	ssSysaddr  uint16
-	scID       uint32
-	scUnit     uint32
-	scReserved [5]uint32
-}
-
 type ifReq struct {
 type ifReq struct {
-	Name  [16]byte
+	Name  [unix.IFNAMSIZ]byte
 	Flags uint16
 	Flags uint16
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-var sockaddrCtlSize uintptr = 32
-
 const (
 const (
-	_SYSPROTO_CONTROL = 2              //define SYSPROTO_CONTROL 2 /* kernel control protocol */
-	_AF_SYS_CONTROL   = 2              //#define AF_SYS_CONTROL 2 /* corresponding sub address type */
-	_PF_SYSTEM        = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM
-	_CTLIOCGINFO      = 3227799043     //#define CTLIOCGINFO     _IOWR('N', 3, struct ctl_info)
-	utunControlName   = "com.apple.net.utun_control"
+	_SIOCAIFADDR_IN6 = 2155899162
+	_UTUN_OPT_IFNAME = 2
+	_IN6_IFF_NODAD   = 0x0020
+	_IN6_IFF_SECURED = 0x0400
+	utunControlName  = "com.apple.net.utun_control"
 )
 )
 
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 type ifreqMTU struct {
 	Name [16]byte
 	Name [16]byte
 	MTU  int32
 	MTU  int32
 	pad  [8]byte
 	pad  [8]byte
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+type addrLifetime struct {
+	Expire    float64
+	Preferred float64
+	Vltime    uint32
+	Pltime    uint32
+}
+
+type ifreqAlias4 struct {
+	Name     [unix.IFNAMSIZ]byte
+	Addr     unix.RawSockaddrInet4
+	DstAddr  unix.RawSockaddrInet4
+	MaskAddr unix.RawSockaddrInet4
+}
+
+type ifreqAlias6 struct {
+	Name       [unix.IFNAMSIZ]byte
+	Addr       unix.RawSockaddrInet6
+	DstAddr    unix.RawSockaddrInet6
+	PrefixMask unix.RawSockaddrInet6
+	Flags      uint32
+	Lifetime   addrLifetime
+}
+
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	name := c.GetString("tun.dev", "")
 	name := c.GetString("tun.dev", "")
 	ifIndex := -1
 	ifIndex := -1
 	if name != "" && name != "utun" {
 	if name != "" && name != "utun" {
@@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 		}
 		}
 	}
 	}
 
 
-	fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL)
+	fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("system socket: %v", err)
 		return nil, fmt.Errorf("system socket: %v", err)
 	}
 	}
 
 
-	var ctlInfo = &struct {
-		ctlID   uint32
-		ctlName [96]byte
-	}{}
+	var ctlInfo = &unix.CtlInfo{}
+	copy(ctlInfo.Name[:], utunControlName)
 
 
-	copy(ctlInfo.ctlName[:], utunControlName)
-
-	err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo)))
+	err = unix.IoctlCtlInfo(fd, ctlInfo)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
 		return nil, fmt.Errorf("CTLIOCGINFO: %v", err)
 	}
 	}
 
 
-	sc := sockaddrCtl{
-		scLen:     uint8(sockaddrCtlSize),
-		scFamily:  unix.AF_SYSTEM,
-		ssSysaddr: _AF_SYS_CONTROL,
-		scID:      ctlInfo.ctlID,
-		scUnit:    uint32(ifIndex) + 1,
-	}
-
-	_, _, errno := unix.RawSyscall(
-		unix.SYS_CONNECT,
-		uintptr(fd),
-		uintptr(unsafe.Pointer(&sc)),
-		sockaddrCtlSize,
-	)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
+	err = unix.Connect(fd, &unix.SockaddrCtl{
+		ID:   ctlInfo.Id,
+		Unit: uint32(ifIndex) + 1,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("SYS_CONNECT: %v", err)
 	}
 	}
 
 
-	var ifName struct {
-		name [16]byte
-	}
-	ifNameSize := uintptr(len(ifName.name))
-	_, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd),
-		2, // SYSPROTO_CONTROL
-		2, // UTUN_OPT_IFNAME
-		uintptr(unsafe.Pointer(&ifName)),
-		uintptr(unsafe.Pointer(&ifNameSize)), 0)
-	if errno != 0 {
-		return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
+	name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME)
+	if err != nil {
+		return nil, fmt.Errorf("failed to retrieve tun name: %w", err)
 	}
 	}
-	name = string(ifName.name[:ifNameSize-1])
 
 
-	err = syscall.SetNonblock(fd, true)
+	err = unix.SetNonblock(fd, true)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("SetNonblock: %v", err)
 		return nil, fmt.Errorf("SetNonblock: %v", err)
 	}
 	}
 
 
-	file := os.NewFile(uintptr(fd), "")
-
 	t := &tun{
 	t := &tun{
-		ReadWriteCloser: file,
+		ReadWriteCloser: os.NewFile(uintptr(fd), ""),
 		Device:          name,
 		Device:          name,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		DefaultMTU:      c.GetInt("tun.mtu", DefaultMTU),
 		DefaultMTU:      c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 }
 
 
@@ -186,16 +167,6 @@ func (t *tun) Close() error {
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 	devName := t.deviceBytes()
 
 
-	var addr, mask [4]byte
-
-	if !t.cidr.Addr().Is4() {
-		//TODO: IPV6-WORK
-		panic("need ipv6")
-	}
-
-	addr = t.cidr.Addr().As4()
-	copy(mask[:], prefixToMask(t.cidr))
-
 	s, err := unix.Socket(
 	s, err := unix.Socket(
 		unix.AF_INET,
 		unix.AF_INET,
 		unix.SOCK_DGRAM,
 		unix.SOCK_DGRAM,
@@ -208,66 +179,18 @@ func (t *tun) Activate() error {
 
 
 	fd := uintptr(s)
 	fd := uintptr(s)
 
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
-	// Set the device name
-	ifrf := ifReq{Name: devName}
-	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to set tun device name: %s", err)
-	}
-
 	// Set the MTU on the device
 	// Set the MTU on the device
 	ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
 	ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)}
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 	if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil {
 		return fmt.Errorf("failed to set tun mtu: %v", err)
 		return fmt.Errorf("failed to set tun mtu: %v", err)
 	}
 	}
 
 
-	/*
-		// Set the transmit queue length
-		ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)}
-		if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil {
-			// If we can't set the queue length nebula will still work but it may lead to packet loss
-			l.WithError(err).Error("Failed to set tun tx queue length")
-		}
-	*/
-
-	// Bring up the interface
-	ifrf.Flags = ifrf.Flags | unix.IFF_UP
-	if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to bring the tun device up: %s", err)
-	}
-
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	// Get the device flags
+	ifrf := ifReq{Name: devName}
+	if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+		return fmt.Errorf("failed to get tun flags: %s", err)
 	}
 	}
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
 
 
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	linkAddr, err := getLinkAddr(t.Device)
 	linkAddr, err := getLinkAddr(t.Device)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -277,14 +200,18 @@ func (t *tun) Activate() error {
 	}
 	}
 	t.linkAddr = linkAddr
 	t.linkAddr = linkAddr
 
 
-	copy(routeAddr.IP[:], addr[:])
-	copy(maskAddr.IP[:], mask[:])
-	err = addRoute(routeSock, routeAddr, maskAddr, linkAddr)
-	if err != nil {
-		if errors.Is(err, unix.EEXIST) {
-			err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr)
+	for _, network := range t.vpnNetworks {
+		if network.Addr().Is4() {
+			err = t.activate4(network)
+			if err != nil {
+				return err
+			}
+		} else {
+			err = t.activate6(network)
+			if err != nil {
+				return err
+			}
 		}
 		}
-		return err
 	}
 	}
 
 
 	// Run the interface
 	// Run the interface
@@ -297,8 +224,89 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) activate4(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias4{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		DstAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   network.Addr().As4(),
+		},
+		MaskAddr: unix.RawSockaddrInet4{
+			Len:    unix.SizeofSockaddrInet4,
+			Family: unix.AF_INET,
+			Addr:   prefixToMask(network).As4(),
+		},
+	}
+
+	if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun v4 address: %s", err)
+	}
+
+	err = addRoute(network, t.linkAddr)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (t *tun) activate6(network netip.Prefix) error {
+	s, err := unix.Socket(
+		unix.AF_INET6,
+		unix.SOCK_DGRAM,
+		unix.IPPROTO_IP,
+	)
+	if err != nil {
+		return err
+	}
+	defer unix.Close(s)
+
+	ifr := ifreqAlias6{
+		Name: t.deviceBytes(),
+		Addr: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   network.Addr().As16(),
+		},
+		PrefixMask: unix.RawSockaddrInet6{
+			Len:    unix.SizeofSockaddrInet6,
+			Family: unix.AF_INET6,
+			Addr:   prefixToMask(network).As16(),
+		},
+		Lifetime: addrLifetime{
+			// never expires
+			Vltime: 0xffffffff,
+			Pltime: 0xffffffff,
+		},
+		//TODO: should we disable DAD (duplicate address detection) and mark this as a secured address?
+		Flags: _IN6_IFF_NODAD,
+	}
+
+	if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil {
+		return fmt.Errorf("failed to set tun address: %s", err)
+	}
+
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) {
 }
 }
 
 
 func (t *tun) addRoutes(logErrors bool) error {
 func (t *tun) addRoutes(logErrors bool) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
+
 	for _, r := range routes {
 	for _, r := range routes {
 		if !r.Via.IsValid() || !r.Install {
 		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
 
 
-		if !r.Cidr.Addr().Is4() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		//TODO: we could avoid the copy
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := addRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 		if err != nil {
 			if errors.Is(err, unix.EEXIST) {
 			if errors.Is(err, unix.EEXIST) {
 				t.l.WithField("route", r.Cidr).
 				t.l.WithField("route", r.Cidr).
@@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error {
 }
 }
 
 
 func (t *tun) removeRoutes(routes []Route) error {
 func (t *tun) removeRoutes(routes []Route) error {
-	routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
-	if err != nil {
-		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
-	}
-
-	defer func() {
-		unix.Shutdown(routeSock, unix.SHUT_RDWR)
-		err := unix.Close(routeSock)
-		if err != nil {
-			t.l.WithError(err).Error("failed to close AF_ROUTE socket")
-		}
-	}()
-
-	routeAddr := &netroute.Inet4Addr{}
-	maskAddr := &netroute.Inet4Addr{}
-
 	for _, r := range routes {
 	for _, r := range routes {
 		if !r.Install {
 		if !r.Install {
 			continue
 			continue
 		}
 		}
 
 
-		if r.Cidr.Addr().Is6() {
-			//TODO: implement ipv6
-			panic("Cant handle ipv6 routes yet")
-		}
-
-		routeAddr.IP = r.Cidr.Addr().As4()
-		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
-
-		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
+		err := delRoute(r.Cidr, t.linkAddr)
 		if err != nil {
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 		} else {
@@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 	return nil
 }
 }
 
 
-func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func addRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := &netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_ADD,
 		Type:    unix.RTM_ADD,
 		Flags:   unix.RTF_UP,
 		Flags:   unix.RTF_UP,
 		Seq:     1,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 	}
 
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
 	}
+
 	_, err = unix.Write(sock, data[:])
 	_, err = unix.Write(sock, data[:])
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
 		return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err)
@@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 	return nil
 	return nil
 }
 }
 
 
-func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error {
-	r := netroute.RouteMessage{
+func delRoute(prefix netip.Prefix, gateway netroute.Addr) error {
+	sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+	if err != nil {
+		return fmt.Errorf("unable to create AF_ROUTE socket: %v", err)
+	}
+	defer unix.Close(sock)
+
+	route := netroute.RouteMessage{
 		Version: unix.RTM_VERSION,
 		Version: unix.RTM_VERSION,
 		Type:    unix.RTM_DELETE,
 		Type:    unix.RTM_DELETE,
 		Seq:     1,
 		Seq:     1,
-		Addrs: []netroute.Addr{
-			unix.RTAX_DST:     addr,
-			unix.RTAX_GATEWAY: link,
-			unix.RTAX_NETMASK: mask,
-		},
 	}
 	}
 
 
-	data, err := r.Marshal()
+	if prefix.Addr().Is4() {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()},
+			unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	} else {
+		route.Addrs = []netroute.Addr{
+			unix.RTAX_DST:     &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()},
+			unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()},
+			unix.RTAX_GATEWAY: gateway,
+		}
+	}
+
+	data, err := route.Marshal()
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 		return fmt.Errorf("failed to create route.RouteMessage: %w", err)
 	}
 	}
@@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr)
 }
 }
 
 
 func (t *tun) Read(to []byte) (int, error) {
 func (t *tun) Read(to []byte) (int, error) {
-
 	buf := make([]byte, len(to)+4)
 	buf := make([]byte, len(to)+4)
 
 
 	n, err := t.ReadWriteCloser.Read(buf)
 	n, err := t.ReadWriteCloser.Read(buf)
@@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) {
 	return n - 4, err
 	return n - 4, err
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {
@@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
 }
 
 
-func prefixToMask(prefix netip.Prefix) []byte {
+func prefixToMask(prefix netip.Prefix) netip.Addr {
 	pLen := 128
 	pLen := 128
 	if prefix.Addr().Is4() {
 	if prefix.Addr().Is4() {
 		pLen = 32
 		pLen = 32
 	}
 	}
-	return net.CIDRMask(prefix.Bits(), pLen)
+
+	addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen))
+	return addr
 }
 }

+ 8 - 8
overlay/tun_disabled.go

@@ -12,8 +12,8 @@ import (
 )
 )
 
 
 type disabledTun struct {
 type disabledTun struct {
-	read chan []byte
-	cidr netip.Prefix
+	read        chan []byte
+	vpnNetworks []netip.Prefix
 
 
 	// Track these metrics since we don't have the tun device to do it for us
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
 	tx metrics.Counter
@@ -21,11 +21,11 @@ type disabledTun struct {
 	l  *logrus.Logger
 	l  *logrus.Logger
 }
 }
 
 
-func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
 	tun := &disabledTun{
-		cidr: cidr,
-		read: make(chan []byte, queueLen),
-		l:    l,
+		vpnNetworks: vpnNetworks,
+		read:        make(chan []byte, queueLen),
+		l:           l,
 	}
 	}
 
 
 	if metricsEnabled {
 	if metricsEnabled {
@@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
 	return netip.Addr{}
 	return netip.Addr{}
 }
 }
 
 
-func (t *disabledTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *disabledTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (*disabledTun) Name() string {
 func (*disabledTun) Name() string {

+ 25 - 15
overlay/tun_freebsd.go

@@ -46,12 +46,12 @@ type ifreqDestroy struct {
 }
 }
 
 
 type tun struct {
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 }
 }
@@ -78,11 +78,11 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
 	// Try to open existing tun device
 	var file *os.File
 	var file *os.File
 	var err error
 	var err error
@@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		Device:          deviceName,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 	return t, nil
 }
 }
 
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	var err error
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device)
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device)
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -195,8 +195,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 	return r
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 10 - 10
overlay/tun_ios.go

@@ -21,20 +21,20 @@ import (
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	cidr      netip.Prefix
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	vpnNetworks []netip.Prefix
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 }
 }
 
 
-func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	t := &tun{
 	t := &tun{
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		ReadWriteCloser: &tunReadCloser{f: file},
 		ReadWriteCloser: &tunReadCloser{f: file},
 		l:               l,
 		l:               l,
 	}
 	}
@@ -59,7 +59,7 @@ func (t *tun) Activate() error {
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error {
 	return tr.f.Close()
 	return tr.f.Close()
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 108 - 69
overlay/tun_linux.go

@@ -25,7 +25,7 @@ type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	fd          int
 	fd          int
 	Device      string
 	Device      string
-	cidr        netip.Prefix
+	vpnNetworks []netip.Prefix
 	MaxMTU      int
 	MaxMTU      int
 	DefaultMTU  int
 	DefaultMTU  int
 	TXQueueLen  int
 	TXQueueLen  int
@@ -40,18 +40,16 @@ type tun struct {
 	l *logrus.Logger
 	l *logrus.Logger
 }
 }
 
 
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
+
 type ifReq struct {
 type ifReq struct {
 	Name  [16]byte
 	Name  [16]byte
 	Flags uint16
 	Flags uint16
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-type ifreqAddr struct {
-	Name [16]byte
-	Addr unix.RawSockaddrInet4
-	pad  [8]byte
-}
-
 type ifreqMTU struct {
 type ifreqMTU struct {
 	Name [16]byte
 	Name [16]byte
 	MTU  int32
 	MTU  int32
@@ -64,10 +62,10 @@ type ifreqQLEN struct {
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 	if err != nil {
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	name := strings.Trim(string(req.Name[:]), "\x00")
 	name := strings.Trim(string(req.Name[:]), "\x00")
 
 
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
 	file := os.NewFile(uintptr(fd), "/dev/net/tun")
-	t, err := newTunGeneric(c, l, file, cidr)
+	t, err := newTunGeneric(c, l, file, vpnNetworks)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
 		TXQueueLen:      c.GetInt("tun.tx_queue", 500),
 		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
 		useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
 		l:               l,
 		l:               l,
@@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -190,11 +188,13 @@ func (t *tun) reload(c *config.C, initial bool) error {
 		}
 		}
 
 
 		if oldDefaultMTU != newDefaultMTU {
 		if oldDefaultMTU != newDefaultMTU {
-			err := t.setDefaultRoute()
-			if err != nil {
-				t.l.Warn(err)
-			} else {
-				t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+			for i := range t.vpnNetworks {
+				err := t.setDefaultRoute(t.vpnNetworks[i])
+				if err != nil {
+					t.l.Warn(err)
+				} else {
+					t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU)
+				}
 			}
 			}
 		}
 		}
 
 
@@ -237,10 +237,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 
 
 func (t *tun) Write(b []byte) (int, error) {
 func (t *tun) Write(b []byte) (int, error) {
 	var nn int
 	var nn int
-	max := len(b)
+	maximum := len(b)
 
 
 	for {
 	for {
-		n, err := unix.Write(t.fd, b[nn:max])
+		n, err := unix.Write(t.fd, b[nn:maximum])
 		if n > 0 {
 		if n > 0 {
 			nn += n
 			nn += n
 		}
 		}
@@ -265,6 +265,58 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
+func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool {
+	for i := range al {
+		if al[i].Equal(x) {
+			return true
+		}
+	}
+	return false
+}
+
+// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there
+func (t *tun) addIPs(link netlink.Link) error {
+	newAddrs := make([]*netlink.Addr, len(t.vpnNetworks))
+	for i := range t.vpnNetworks {
+		newAddrs[i] = &netlink.Addr{
+			IPNet: &net.IPNet{
+				IP:   t.vpnNetworks[i].Addr().AsSlice(),
+				Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()),
+			},
+			Label: t.vpnNetworks[i].Addr().Zone(),
+		}
+	}
+
+	//add all new addresses
+	for i := range newAddrs {
+		//todo do we want to stack errors and try as many ops as possible?
+		//AddrReplace still adds new IPs, but if their properties change it will change them as well
+		if err := netlink.AddrReplace(link, newAddrs[i]); err != nil {
+			return err
+		}
+	}
+
+	//iterate over remainder, remove whoever shouldn't be there
+	al, err := netlink.AddrList(link, netlink.FAMILY_ALL)
+	if err != nil {
+		return fmt.Errorf("failed to get tun address list: %s", err)
+	}
+
+	for i := range al {
+		if hasNetlinkAddr(newAddrs, al[i]) {
+			continue
+		}
+		err = netlink.AddrDel(link, &al[i])
+		if err != nil {
+			t.l.WithError(err).Error("failed to remove address from tun address list")
+		} else {
+			t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)")
+		}
+	}
+
+	return nil
+}
+
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
 	devName := t.deviceBytes()
 	devName := t.deviceBytes()
 
 
@@ -272,15 +324,8 @@ func (t *tun) Activate() error {
 		t.watchRoutes()
 		t.watchRoutes()
 	}
 	}
 
 
-	var addr, mask [4]byte
-
-	//TODO: IPV6-WORK
-	addr = t.cidr.Addr().As4()
-	tmask := net.CIDRMask(t.cidr.Bits(), 32)
-	copy(mask[:], tmask)
-
 	s, err := unix.Socket(
 	s, err := unix.Socket(
-		unix.AF_INET,
+		unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine
 		unix.SOCK_DGRAM,
 		unix.SOCK_DGRAM,
 		unix.IPPROTO_IP,
 		unix.IPPROTO_IP,
 	)
 	)
@@ -289,31 +334,19 @@ func (t *tun) Activate() error {
 	}
 	}
 	t.ioctlFd = uintptr(s)
 	t.ioctlFd = uintptr(s)
 
 
-	ifra := ifreqAddr{
-		Name: devName,
-		Addr: unix.RawSockaddrInet4{
-			Family: unix.AF_INET,
-			Addr:   addr,
-		},
-	}
-
-	// Set the device ip address
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun address: %s", err)
-	}
-
-	// Set the device network
-	ifra.Addr.Addr = mask
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
-		return fmt.Errorf("failed to set tun netmask: %s", err)
-	}
-
 	// Set the device name
 	// Set the device name
 	ifrf := ifReq{Name: devName}
 	ifrf := ifReq{Name: devName}
 	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 	if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to set tun device name: %s", err)
 		return fmt.Errorf("failed to set tun device name: %s", err)
 	}
 	}
 
 
+	link, err := netlink.LinkByName(t.Device)
+	if err != nil {
+		return fmt.Errorf("failed to get tun device link: %s", err)
+	}
+
+	t.deviceIndex = link.Attrs().Index
+
 	// Setup our default MTU
 	// Setup our default MTU
 	t.setMTU()
 	t.setMTU()
 
 
@@ -324,20 +357,27 @@ func (t *tun) Activate() error {
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 		t.l.WithError(err).Error("Failed to set tun tx queue length")
 	}
 	}
 
 
+	if err = t.addIPs(link); err != nil {
+		return err
+	}
+
 	// Bring up the interface
 	// Bring up the interface
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP
 	ifrf.Flags = ifrf.Flags | unix.IFF_UP
 	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
 		return fmt.Errorf("failed to bring the tun device up: %s", err)
 		return fmt.Errorf("failed to bring the tun device up: %s", err)
 	}
 	}
 
 
-	link, err := netlink.LinkByName(t.Device)
-	if err != nil {
-		return fmt.Errorf("failed to get tun device link: %s", err)
+	// Run the interface
+	ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
+	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
+		return fmt.Errorf("failed to run tun device: %s", err)
 	}
 	}
-	t.deviceIndex = link.Attrs().Index
 
 
-	if err = t.setDefaultRoute(); err != nil {
-		return err
+	//set route MTU
+	for i := range t.vpnNetworks {
+		if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
+			return fmt.Errorf("failed to set default route MTU: %w", err)
+		}
 	}
 	}
 
 
 	// Set the routes
 	// Set the routes
@@ -345,11 +385,7 @@ func (t *tun) Activate() error {
 		return err
 		return err
 	}
 	}
 
 
-	// Run the interface
-	ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
-	if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil {
-		return fmt.Errorf("failed to run tun device: %s", err)
-	}
+	//todo do we want to keep the link-local address?
 
 
 	return nil
 	return nil
 }
 }
@@ -363,12 +399,12 @@ func (t *tun) setMTU() {
 	}
 	}
 }
 }
 
 
-func (t *tun) setDefaultRoute() error {
+func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
 	// Default route
 	// Default route
 
 
 	dr := &net.IPNet{
 	dr := &net.IPNet{
-		IP:   t.cidr.Masked().Addr().AsSlice(),
-		Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+		IP:   cidr.Masked().Addr().AsSlice(),
+		Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()),
 	}
 	}
 
 
 	nr := netlink.Route{
 	nr := netlink.Route{
@@ -377,7 +413,7 @@ func (t *tun) setDefaultRoute() error {
 		MTU:       t.DefaultMTU,
 		MTU:       t.DefaultMTU,
 		AdvMSS:    t.advMSS(Route{}),
 		AdvMSS:    t.advMSS(Route{}),
 		Scope:     unix.RT_SCOPE_LINK,
 		Scope:     unix.RT_SCOPE_LINK,
-		Src:       net.IP(t.cidr.Addr().AsSlice()),
+		Src:       net.IP(cidr.Addr().AsSlice()),
 		Protocol:  unix.RTPROT_KERNEL,
 		Protocol:  unix.RTPROT_KERNEL,
 		Table:     unix.RT_TABLE_MAIN,
 		Table:     unix.RT_TABLE_MAIN,
 		Type:      unix.RTN_UNICAST,
 		Type:      unix.RTN_UNICAST,
@@ -463,10 +499,6 @@ func (t *tun) removeRoutes(routes []Route) {
 	}
 	}
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
 func (t *tun) Name() string {
 func (t *tun) Name() string {
 	return t.Device
 	return t.Device
 }
 }
@@ -523,9 +555,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	}
 	}
 
 
 	gwAddr = gwAddr.Unmap()
 	gwAddr = gwAddr.Unmap()
-	if !t.cidr.Contains(gwAddr) {
+	withinNetworks := false
+	for i := range t.vpnNetworks {
+		if t.vpnNetworks[i].Contains(gwAddr) {
+			withinNetworks = true
+			break
+		}
+	}
+	if !withinNetworks {
 		// Gateway isn't in our overlay network, ignore
 		// Gateway isn't in our overlay network, ignore
-		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
+		t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
 		return
 		return
 	}
 	}
 
 

+ 28 - 17
overlay/tun_netbsd.go

@@ -27,12 +27,12 @@ type ifreqDestroy struct {
 }
 }
 
 
 type tun struct {
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 }
 }
@@ -58,13 +58,13 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 }
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
 	// Try to open tun device
 	var file *os.File
 	var file *os.File
 	var err error
 	var err error
@@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		Device:          deviceName,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	return t, nil
 	return t, nil
 }
 }
 
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	var err error
 
 
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -130,8 +130,18 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	return r
 	return r
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {
@@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
+		//todo is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

+ 29 - 19
overlay/tun_openbsd.go

@@ -21,12 +21,12 @@ import (
 )
 )
 
 
 type tun struct {
 type tun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 
 
@@ -42,13 +42,13 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 }
 }
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
 	if deviceName == "" {
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		Device:          deviceName,
 		Device:          deviceName,
-		cidr:            cidr,
+		vpnNetworks:     vpnNetworks,
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		MTU:             c.GetInt("tun.mtu", DefaultMTU),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err
 }
 }
 
 
 func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Activate() error {
+func (t *tun) addIp(cidr netip.Prefix) error {
 	var err error
 	var err error
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -138,7 +138,7 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -148,6 +148,16 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
+func (t *tun) Activate() error {
+	for i := range t.vpnNetworks {
+		err := t.addIp(t.vpnNetworks[i])
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
 	r, _ := t.routeTree.Load().Lookup(ip)
 	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
@@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//todo is this right?
+		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 		if !r.Install {
 		if !r.Install {
 			continue
 			continue
 		}
 		}
-
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
+		//todo is this right?
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *tun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *tun) Name() string {
 func (t *tun) Name() string {

+ 17 - 17
overlay/tun_tester.go

@@ -16,19 +16,19 @@ import (
 )
 )
 
 
 type TestTun struct {
 type TestTun struct {
-	Device    string
-	cidr      netip.Prefix
-	Routes    []Route
-	routeTree *bart.Table[netip.Addr]
-	l         *logrus.Logger
+	Device      string
+	vpnNetworks []netip.Prefix
+	Routes      []Route
+	routeTree   *bart.Table[netip.Addr]
+	l           *logrus.Logger
 
 
 	closed    atomic.Bool
 	closed    atomic.Bool
 	rxPackets chan []byte // Packets to receive into nebula
 	rxPackets chan []byte // Packets to receive into nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
-	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
+	_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
@@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun,
 	}
 	}
 
 
 	return &TestTun{
 	return &TestTun{
-		Device:    c.GetString("tun.dev", ""),
-		cidr:      cidr,
-		Routes:    routes,
-		routeTree: routeTree,
-		l:         l,
-		rxPackets: make(chan []byte, 10),
-		TxPackets: make(chan []byte, 10),
+		Device:      c.GetString("tun.dev", ""),
+		vpnNetworks: vpnNetworks,
+		Routes:      routes,
+		routeTree:   routeTree,
+		l:           l,
+		rxPackets:   make(chan []byte, 10),
+		TxPackets:   make(chan []byte, 10),
 	}, nil
 	}, nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 }
 
 
@@ -95,8 +95,8 @@ func (t *TestTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *TestTun) Cidr() netip.Prefix {
-	return t.cidr
+func (t *TestTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
 }
 }
 
 
 func (t *TestTun) Name() string {
 func (t *TestTun) Name() string {

+ 0 - 208
overlay/tun_water_windows.go

@@ -1,208 +0,0 @@
-package overlay
-
-import (
-	"fmt"
-	"io"
-	"net"
-	"net/netip"
-	"os/exec"
-	"strconv"
-	"sync/atomic"
-
-	"github.com/gaissmai/bart"
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
-	"github.com/songgao/water"
-)
-
-type waterTun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
-	f         *net.Interface
-	*water.Interface
-}
-
-func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
-	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
-	t := &waterTun{
-		cidr: cidr,
-		MTU:  c.GetInt("tun.mtu", DefaultMTU),
-		l:    l,
-	}
-
-	err := t.reload(c, true)
-	if err != nil {
-		return nil, err
-	}
-
-	c.RegisterReloadCallback(func(c *config.C) {
-		err := t.reload(c, false)
-		if err != nil {
-			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
-		}
-	})
-
-	return t, nil
-}
-
-func (t *waterTun) Activate() error {
-	var err error
-	t.Interface, err = water.New(water.Config{
-		DeviceType: water.TUN,
-		PlatformSpecificParams: water.PlatformSpecificParams{
-			ComponentID: "tap0901",
-			Network:     t.cidr.String(),
-		},
-	})
-	if err != nil {
-		return fmt.Errorf("activate failed: %v", err)
-	}
-
-	t.Device = t.Interface.Name()
-
-	// TODO use syscalls instead of exec.Command
-	err = exec.Command(
-		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
-		fmt.Sprintf("name=%s", t.Device),
-		"source=static",
-		fmt.Sprintf("addr=%s", t.cidr.Addr()),
-		fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
-		"gateway=none",
-	).Run()
-	if err != nil {
-		return fmt.Errorf("failed to run 'netsh' to set address: %s", err)
-	}
-	err = exec.Command(
-		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface",
-		t.Device,
-		fmt.Sprintf("mtu=%d", t.MTU),
-	).Run()
-	if err != nil {
-		return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err)
-	}
-
-	t.f, err = net.InterfaceByName(t.Device)
-	if err != nil {
-		return fmt.Errorf("failed to find interface named %s: %v", t.Device, err)
-	}
-
-	err = t.addRoutes(false)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (t *waterTun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
-	if err != nil {
-		return err
-	}
-
-	if !initial && !change {
-		return nil
-	}
-
-	routeTree, err := makeRouteTree(t.l, routes, false)
-	if err != nil {
-		return err
-	}
-
-	// Teach nebula how to handle the routes before establishing them in the system table
-	oldRoutes := t.Routes.Swap(&routes)
-	t.routeTree.Store(routeTree)
-
-	if !initial {
-		// Remove first, if the system removes a wanted route hopefully it will be re-added next
-		t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
-
-		// Ensure any routes we actually want are installed
-		err = t.addRoutes(true)
-		if err != nil {
-			// Catch any stray logs
-			util.LogWithContextIfNeeded("Failed to set routes", err, t.l)
-		} else {
-			for _, r := range findRemovedRoutes(routes, *oldRoutes) {
-				t.l.WithField("route", r).Info("Removed route")
-			}
-		}
-	}
-
-	return nil
-}
-
-func (t *waterTun) addRoutes(logErrors bool) error {
-	// Path routes
-	routes := *t.Routes.Load()
-	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
-			// We don't allow route MTUs so only install routes with a via
-			continue
-		}
-
-		err := exec.Command(
-			"C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
-		).Run()
-
-		if err != nil {
-			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
-			if logErrors {
-				retErr.Log(t.l)
-			} else {
-				return retErr
-			}
-		} else {
-			t.l.WithField("route", r).Info("Added route")
-		}
-	}
-
-	return nil
-}
-
-func (t *waterTun) removeRoutes(routes []Route) {
-	for _, r := range routes {
-		if !r.Install {
-			continue
-		}
-
-		err := exec.Command(
-			"C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric),
-		).Run()
-		if err != nil {
-			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
-		} else {
-			t.l.WithField("route", r).Info("Removed route")
-		}
-	}
-}
-
-func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
-	r, _ := t.routeTree.Load().Lookup(ip)
-	return r
-}
-
-func (t *waterTun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
-func (t *waterTun) Name() string {
-	return t.Device
-}
-
-func (t *waterTun) Close() error {
-	if t.Interface == nil {
-		return nil
-	}
-
-	return t.Interface.Close()
-}
-
-func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
-}

+ 239 - 13
overlay/tun_windows.go

@@ -4,41 +4,267 @@
 package overlay
 package overlay
 
 
 import (
 import (
+	"crypto"
 	"fmt"
 	"fmt"
+	"io"
 	"net/netip"
 	"net/netip"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
+	"sync/atomic"
 	"syscall"
 	"syscall"
+	"unsafe"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
+	"github.com/slackhq/nebula/wintun"
+	"golang.org/x/sys/windows"
+	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
 )
 )
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
+const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
+
+type winTun struct {
+	Device      string
+	vpnNetworks []netip.Prefix
+	MTU         int
+	Routes      atomic.Pointer[[]Route]
+	routeTree   atomic.Pointer[bart.Table[netip.Addr]]
+	l           *logrus.Logger
+
+	tun *wintun.NativeTun
+}
+
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
-	useWintun := true
-	if err := checkWinTunExists(); err != nil {
-		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
-		useWintun = false
+func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
+	err := checkWinTunExists()
+	if err != nil {
+		return nil, fmt.Errorf("can not load the wintun driver: %w", err)
+	}
+
+	deviceName := c.GetString("tun.dev", "")
+	guid, err := generateGUIDByDeviceName(deviceName)
+	if err != nil {
+		return nil, fmt.Errorf("generate GUID failed: %w", err)
+	}
+
+	t := &winTun{
+		Device:      deviceName,
+		vpnNetworks: vpnNetworks,
+		MTU:         c.GetInt("tun.mtu", DefaultMTU),
+		l:           l,
+	}
+
+	err = t.reload(c, true)
+	if err != nil {
+		return nil, err
+	}
+
+	var tunDevice wintun.Device
+	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
+	if err != nil {
+		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
+		// Trying a second time resolves the issue.
+		l.WithError(err).Debug("Failed to create wintun device, retrying")
+		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
+		if err != nil {
+			return nil, fmt.Errorf("create TUN device failed: %w", err)
+		}
+	}
+	t.tun = tunDevice.(*wintun.NativeTun)
+
+	c.RegisterReloadCallback(func(c *config.C) {
+		err := t.reload(c, false)
+		if err != nil {
+			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
+		}
+	})
+
+	return t, nil
+}
+
+func (t *winTun) reload(c *config.C, initial bool) error {
+	change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial)
+	if err != nil {
+		return err
+	}
+
+	if !initial && !change {
+		return nil
+	}
+
+	routeTree, err := makeRouteTree(t.l, routes, false)
+	if err != nil {
+		return err
 	}
 	}
 
 
-	if useWintun {
-		device, err := newWinTun(c, l, cidr, multiqueue)
+	// Teach nebula how to handle the routes before establishing them in the system table
+	oldRoutes := t.Routes.Swap(&routes)
+	t.routeTree.Store(routeTree)
+
+	if !initial {
+		// Remove first, if the system removes a wanted route hopefully it will be re-added next
+		err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
+		if err != nil {
+			util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
+		}
+
+		// Ensure any routes we actually want are installed
+		err = t.addRoutes(true)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("create Wintun interface failed, %w", err)
+			// Catch any stray logs
+			util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
 		}
 		}
-		return device, nil
 	}
 	}
 
 
-	device, err := newWaterTun(c, l, cidr, multiqueue)
+	return nil
+}
+
+func (t *winTun) Activate() error {
+	luid := winipcfg.LUID(t.tun.LUID())
+
+	err := luid.SetIPAddresses(t.vpnNetworks)
+	if err != nil {
+		return fmt.Errorf("failed to set address: %w", err)
+	}
+
+	err = t.addRoutes(false)
 	if err != nil {
 	if err != nil {
-		return nil, fmt.Errorf("create wintap driver failed, %w", err)
+		return err
 	}
 	}
-	return device, nil
+
+	return nil
+}
+
+func (t *winTun) addRoutes(logErrors bool) error {
+	luid := winipcfg.LUID(t.tun.LUID())
+	routes := *t.Routes.Load()
+	foundDefault4 := false
+
+	for _, r := range routes {
+		if !r.Via.IsValid() || !r.Install {
+			// We don't allow route MTUs so only install routes with a via
+			continue
+		}
+
+		// Add our unsafe route
+		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
+		if err != nil {
+			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
+			if logErrors {
+				retErr.Log(t.l)
+				continue
+			} else {
+				return retErr
+			}
+		} else {
+			t.l.WithField("route", r).Info("Added route")
+		}
+
+		if !foundDefault4 {
+			if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
+				foundDefault4 = true
+			}
+		}
+	}
+
+	ipif, err := luid.IPInterface(windows.AF_INET)
+	if err != nil {
+		return fmt.Errorf("failed to get ip interface: %w", err)
+	}
+
+	ipif.NLMTU = uint32(t.MTU)
+	if foundDefault4 {
+		ipif.UseAutomaticMetric = false
+		ipif.Metric = 0
+	}
+
+	if err := ipif.Set(); err != nil {
+		return fmt.Errorf("failed to set ip interface: %w", err)
+	}
+	return nil
+}
+
+func (t *winTun) removeRoutes(routes []Route) error {
+	luid := winipcfg.LUID(t.tun.LUID())
+
+	for _, r := range routes {
+		if !r.Install {
+			continue
+		}
+
+		err := luid.DeleteRoute(r.Cidr, r.Via)
+		if err != nil {
+			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
+		} else {
+			t.l.WithField("route", r).Info("Removed route")
+		}
+	}
+	return nil
+}
+
+func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
+	return r
+}
+
+func (t *winTun) Networks() []netip.Prefix {
+	return t.vpnNetworks
+}
+
+func (t *winTun) Name() string {
+	return t.Device
+}
+
+func (t *winTun) Read(b []byte) (int, error) {
+	return t.tun.Read(b, 0)
+}
+
+func (t *winTun) Write(b []byte) (int, error) {
+	return t.tun.Write(b, 0)
+}
+
+func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
+	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
+}
+
+func (t *winTun) Close() error {
+	// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
+	// so to be certain, just remove everything before destroying.
+	luid := winipcfg.LUID(t.tun.LUID())
+	_ = luid.FlushRoutes(windows.AF_INET)
+	_ = luid.FlushIPAddresses(windows.AF_INET)
+	/* We don't support IPV6 yet
+	_ = luid.FlushRoutes(windows.AF_INET6)
+	_ = luid.FlushIPAddresses(windows.AF_INET6)
+	*/
+	_ = luid.FlushDNS(windows.AF_INET)
+
+	return t.tun.Close()
+}
+
+func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
+	// GUID is 128 bit
+	hash := crypto.MD5.New()
+
+	_, err := hash.Write([]byte(tunGUIDLabel))
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = hash.Write([]byte(name))
+	if err != nil {
+		return nil, err
+	}
+
+	sum := hash.Sum(nil)
+
+	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 }
 }
 
 
 func checkWinTunExists() error {
 func checkWinTunExists() error {

+ 0 - 252
overlay/tun_wintun_windows.go

@@ -1,252 +0,0 @@
-package overlay
-
-import (
-	"crypto"
-	"fmt"
-	"io"
-	"net/netip"
-	"sync/atomic"
-	"unsafe"
-
-	"github.com/gaissmai/bart"
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/util"
-	"github.com/slackhq/nebula/wintun"
-	"golang.org/x/sys/windows"
-	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
-)
-
-const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
-
-type winTun struct {
-	Device    string
-	cidr      netip.Prefix
-	MTU       int
-	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[bart.Table[netip.Addr]]
-	l         *logrus.Logger
-
-	tun *wintun.NativeTun
-}
-
-func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
-	// GUID is 128 bit
-	hash := crypto.MD5.New()
-
-	_, err := hash.Write([]byte(tunGUIDLabel))
-	if err != nil {
-		return nil, err
-	}
-
-	_, err = hash.Write([]byte(name))
-	if err != nil {
-		return nil, err
-	}
-
-	sum := hash.Sum(nil)
-
-	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
-}
-
-func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
-	deviceName := c.GetString("tun.dev", "")
-	guid, err := generateGUIDByDeviceName(deviceName)
-	if err != nil {
-		return nil, fmt.Errorf("generate GUID failed: %w", err)
-	}
-
-	t := &winTun{
-		Device: deviceName,
-		cidr:   cidr,
-		MTU:    c.GetInt("tun.mtu", DefaultMTU),
-		l:      l,
-	}
-
-	err = t.reload(c, true)
-	if err != nil {
-		return nil, err
-	}
-
-	var tunDevice wintun.Device
-	tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
-	if err != nil {
-		// Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device.
-		// Trying a second time resolves the issue.
-		l.WithError(err).Debug("Failed to create wintun device, retrying")
-		tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU)
-		if err != nil {
-			return nil, fmt.Errorf("create TUN device failed: %w", err)
-		}
-	}
-	t.tun = tunDevice.(*wintun.NativeTun)
-
-	c.RegisterReloadCallback(func(c *config.C) {
-		err := t.reload(c, false)
-		if err != nil {
-			util.LogWithContextIfNeeded("failed to reload tun device", err, t.l)
-		}
-	})
-
-	return t, nil
-}
-
-func (t *winTun) reload(c *config.C, initial bool) error {
-	change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial)
-	if err != nil {
-		return err
-	}
-
-	if !initial && !change {
-		return nil
-	}
-
-	routeTree, err := makeRouteTree(t.l, routes, false)
-	if err != nil {
-		return err
-	}
-
-	// Teach nebula how to handle the routes before establishing them in the system table
-	oldRoutes := t.Routes.Swap(&routes)
-	t.routeTree.Store(routeTree)
-
-	if !initial {
-		// Remove first, if the system removes a wanted route hopefully it will be re-added next
-		err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes))
-		if err != nil {
-			util.LogWithContextIfNeeded("Failed to remove routes", err, t.l)
-		}
-
-		// Ensure any routes we actually want are installed
-		err = t.addRoutes(true)
-		if err != nil {
-			// Catch any stray logs
-			util.LogWithContextIfNeeded("Failed to add routes", err, t.l)
-		}
-	}
-
-	return nil
-}
-
-func (t *winTun) Activate() error {
-	luid := winipcfg.LUID(t.tun.LUID())
-
-	err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
-	if err != nil {
-		return fmt.Errorf("failed to set address: %w", err)
-	}
-
-	err = t.addRoutes(false)
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (t *winTun) addRoutes(logErrors bool) error {
-	luid := winipcfg.LUID(t.tun.LUID())
-	routes := *t.Routes.Load()
-	foundDefault4 := false
-
-	for _, r := range routes {
-		if !r.Via.IsValid() || !r.Install {
-			// We don't allow route MTUs so only install routes with a via
-			continue
-		}
-
-		// Add our unsafe route
-		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
-		if err != nil {
-			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
-			if logErrors {
-				retErr.Log(t.l)
-				continue
-			} else {
-				return retErr
-			}
-		} else {
-			t.l.WithField("route", r).Info("Added route")
-		}
-
-		if !foundDefault4 {
-			if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
-				foundDefault4 = true
-			}
-		}
-	}
-
-	ipif, err := luid.IPInterface(windows.AF_INET)
-	if err != nil {
-		return fmt.Errorf("failed to get ip interface: %w", err)
-	}
-
-	ipif.NLMTU = uint32(t.MTU)
-	if foundDefault4 {
-		ipif.UseAutomaticMetric = false
-		ipif.Metric = 0
-	}
-
-	if err := ipif.Set(); err != nil {
-		return fmt.Errorf("failed to set ip interface: %w", err)
-	}
-	return nil
-}
-
-func (t *winTun) removeRoutes(routes []Route) error {
-	luid := winipcfg.LUID(t.tun.LUID())
-
-	for _, r := range routes {
-		if !r.Install {
-			continue
-		}
-
-		err := luid.DeleteRoute(r.Cidr, r.Via)
-		if err != nil {
-			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
-		} else {
-			t.l.WithField("route", r).Info("Removed route")
-		}
-	}
-	return nil
-}
-
-func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
-	r, _ := t.routeTree.Load().Lookup(ip)
-	return r
-}
-
-func (t *winTun) Cidr() netip.Prefix {
-	return t.cidr
-}
-
-func (t *winTun) Name() string {
-	return t.Device
-}
-
-func (t *winTun) Read(b []byte) (int, error) {
-	return t.tun.Read(b, 0)
-}
-
-func (t *winTun) Write(b []byte) (int, error) {
-	return t.tun.Write(b, 0)
-}
-
-func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
-	return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
-}
-
-func (t *winTun) Close() error {
-	// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes,
-	// so to be certain, just remove everything before destroying.
-	luid := winipcfg.LUID(t.tun.LUID())
-	_ = luid.FlushRoutes(windows.AF_INET)
-	_ = luid.FlushIPAddresses(windows.AF_INET)
-	/* We don't support IPV6 yet
-	_ = luid.FlushRoutes(windows.AF_INET6)
-	_ = luid.FlushIPAddresses(windows.AF_INET6)
-	*/
-	_ = luid.FlushDNS(windows.AF_INET)
-
-	return t.tun.Close()
-}

+ 6 - 6
overlay/user.go

@@ -8,16 +8,16 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 )
 )
 
 
-func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
-	return NewUserDevice(tunCidr)
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
+	return NewUserDevice(vpnNetworks)
 }
 }
 
 
-func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
+func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
 	// these pipes guarantee each write/read will match 1:1
 	// these pipes guarantee each write/read will match 1:1
 	or, ow := io.Pipe()
 	or, ow := io.Pipe()
 	ir, iw := io.Pipe()
 	ir, iw := io.Pipe()
 	return &UserDevice{
 	return &UserDevice{
-		tunCidr:        tunCidr,
+		vpnNetworks:    vpnNetworks,
 		outboundReader: or,
 		outboundReader: or,
 		outboundWriter: ow,
 		outboundWriter: ow,
 		inboundReader:  ir,
 		inboundReader:  ir,
@@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
 }
 }
 
 
 type UserDevice struct {
 type UserDevice struct {
-	tunCidr netip.Prefix
+	vpnNetworks []netip.Prefix
 
 
 	outboundReader *io.PipeReader
 	outboundReader *io.PipeReader
 	outboundWriter *io.PipeWriter
 	outboundWriter *io.PipeWriter
@@ -38,7 +38,7 @@ type UserDevice struct {
 func (d *UserDevice) Activate() error {
 func (d *UserDevice) Activate() error {
 	return nil
 	return nil
 }
 }
-func (d *UserDevice) Cidr() netip.Prefix                { return d.tunCidr }
+func (d *UserDevice) Networks() []netip.Prefix          { return d.vpnNetworks }
 func (d *UserDevice) Name() string                      { return "faketun0" }
 func (d *UserDevice) Name() string                      { return "faketun0" }
 func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
 func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {

+ 338 - 67
pki.go

@@ -1,13 +1,19 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"encoding/binary"
+	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"net"
+	"net/netip"
 	"os"
 	"os"
+	"slices"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -21,12 +27,22 @@ type PKI struct {
 }
 }
 
 
 type CertState struct {
 type CertState struct {
-	Certificate         cert.Certificate
-	RawCertificate      []byte
-	RawCertificateNoKey []byte
-	PublicKey           []byte
-	PrivateKey          []byte
-	pkcs11Backed        bool
+	v1Cert           cert.Certificate
+	v1HandshakeBytes []byte
+
+	v2Cert           cert.Certificate
+	v2HandshakeBytes []byte
+
+	defaultVersion cert.Version
+	privateKey     []byte
+	pkcs11Backed   bool
+	cipher         string
+
+	myVpnNetworks            []netip.Prefix
+	myVpnNetworksTable       *bart.Table[struct{}]
+	myVpnAddrs               []netip.Addr
+	myVpnAddrsTable          *bart.Table[struct{}]
+	myVpnBroadcastAddrsTable *bart.Table[struct{}]
 }
 }
 
 
 func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
 func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -46,16 +62,26 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
 	return pki, nil
 	return pki, nil
 }
 }
 
 
-func (p *PKI) GetCertState() *CertState {
+func (p *PKI) GetCAPool() *cert.CAPool {
+	return p.caPool.Load()
+}
+
+func (p *PKI) getCertState() *CertState {
 	return p.cs.Load()
 	return p.cs.Load()
 }
 }
 
 
-func (p *PKI) GetCAPool() *cert.CAPool {
-	return p.caPool.Load()
+// TODO: We should remove this
+func (p *PKI) getDefaultCertificate() cert.Certificate {
+	return p.cs.Load().GetDefaultCertificate()
+}
+
+// TODO: We should remove this
+func (p *PKI) getCertificate(v cert.Version) cert.Certificate {
+	return p.cs.Load().getCertificate(v)
 }
 }
 
 
 func (p *PKI) reload(c *config.C, initial bool) error {
 func (p *PKI) reload(c *config.C, initial bool) error {
-	err := p.reloadCert(c, initial)
+	err := p.reloadCerts(c, initial)
 	if err != nil {
 	if err != nil {
 		if initial {
 		if initial {
 			return err
 			return err
@@ -74,33 +100,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
-	cs, err := newCertStateFromConfig(c)
+func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
+	newState, err := newCertStateFromConfig(c)
 	if err != nil {
 	if err != nil {
 		return util.NewContextualError("Could not load client cert", nil, err)
 		return util.NewContextualError("Could not load client cert", nil, err)
 	}
 	}
 
 
 	if !initial {
 	if !initial {
-		//TODO: include check for mask equality as well
+		currentState := p.cs.Load()
+		if newState.v1Cert != nil {
+			if currentState.v1Cert == nil {
+				return util.NewContextualError("v1 certificate was added, restart required", nil, err)
+			}
+
+			// did IP in cert change? if so, don't set
+			if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
+				return util.NewContextualError(
+					"Networks in new cert was different from old",
+					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
+					nil,
+				)
+			}
+
+			if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
+				return util.NewContextualError(
+					"Curve in new cert was different from old",
+					m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
+					nil,
+				)
+			}
+
+		} else if currentState.v1Cert != nil {
+			//TODO: we should be able to tear this down
+			return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
+		}
+
+		if newState.v2Cert != nil {
+			if currentState.v2Cert == nil {
+				return util.NewContextualError("v2 certificate was added, restart required", nil, err)
+			}
 
 
-		// did IP in cert change? if so, don't set
-		currentCert := p.cs.Load().Certificate
-		oldIPs := currentCert.Networks()
-		newIPs := cs.Certificate.Networks()
-		if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
+			// did IP in cert change? if so, don't set
+			if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
+				return util.NewContextualError(
+					"Networks in new cert was different from old",
+					m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
+					nil,
+				)
+			}
+
+			if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+				return util.NewContextualError(
+					"Curve in new cert was different from old",
+					m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+					nil,
+				)
+			}
+
+		} else if currentState.v2Cert != nil {
+			return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
+		}
+
+		// Cipher cant be hot swapped so just leave it at what it was before
+		newState.cipher = currentState.cipher
+
+	} else {
+		newState.cipher = c.GetString("cipher", "aes")
+		//TODO: this sucks and we should make it not a global
+		switch newState.cipher {
+		case "aes":
+			noiseEndianness = binary.BigEndian
+		case "chachapoly":
+			noiseEndianness = binary.LittleEndian
+		default:
 			return util.NewContextualError(
 			return util.NewContextualError(
-				"Networks in new cert was different from old",
-				m{"new_network": newIPs[0], "old_network": oldIPs[0]},
+				"unknown cipher",
+				m{"cipher": newState.cipher},
 				nil,
 				nil,
 			)
 			)
 		}
 		}
 	}
 	}
 
 
-	p.cs.Store(cs)
+	p.cs.Store(newState)
+
+	//TODO: newState needs a stringer that does json
 	if initial {
 	if initial {
-		p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
+		p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
 	} else {
 	} else {
-		p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
+		p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
 	}
 	}
 	return nil
 	return nil
 }
 }
@@ -116,55 +203,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
 	return nil
 	return nil
 }
 }
 
 
-func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
-	// Marshal the certificate to ensure it is valid
-	rawCertificate, err := certificate.Marshal()
-	if err != nil {
-		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
+func (cs *CertState) GetDefaultCertificate() cert.Certificate {
+	c := cs.getCertificate(cs.defaultVersion)
+	if c == nil {
+		panic("No default certificate found")
 	}
 	}
+	return c
+}
 
 
-	publicKey := certificate.PublicKey()
-	cs := &CertState{
-		RawCertificate: rawCertificate,
-		Certificate:    certificate,
-		PrivateKey:     privateKey,
-		PublicKey:      publicKey,
-		pkcs11Backed:   pkcs11backed,
+func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
+	switch v {
+	case cert.Version1:
+		return cs.v1Cert
+	case cert.Version2:
+		return cs.v2Cert
 	}
 	}
 
 
-	rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
-	if err != nil {
-		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
+	return nil
+}
+
+// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
+// Callers must check if the return []byte is nil.
+func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
+	switch v {
+	case cert.Version1:
+		return cs.v1HandshakeBytes
+	case cert.Version2:
+		return cs.v2HandshakeBytes
+	default:
+		return nil
 	}
 	}
-	cs.RawCertificateNoKey = rawCertNoKey
+}
 
 
-	return cs, nil
+func (cs *CertState) String() string {
+	b, err := cs.MarshalJSON()
+	if err != nil {
+		return fmt.Sprintf("error marshaling certificate state: %v", err)
+	}
+	return string(b)
 }
 }
 
 
-func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
-	var pemPrivateKey []byte
-	if strings.Contains(privPathOrPEM, "-----BEGIN") {
-		pemPrivateKey = []byte(privPathOrPEM)
-		privPathOrPEM = "<inline>"
-		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+func (cs *CertState) MarshalJSON() ([]byte, error) {
+	msg := []json.RawMessage{}
+	if cs.v1Cert != nil {
+		b, err := cs.v1Cert.MarshalJSON()
 		if err != nil {
 		if err != nil {
-			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+			return nil, err
 		}
 		}
-	} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
-		rawKey = []byte(privPathOrPEM)
-		return rawKey, cert.Curve_P256, true, nil
-	} else {
-		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
-		if err != nil {
-			return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
-		}
-		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+		msg = append(msg, b)
+	}
+
+	if cs.v2Cert != nil {
+		b, err := cs.v2Cert.MarshalJSON()
 		if err != nil {
 		if err != nil {
-			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+			return nil, err
 		}
 		}
+		msg = append(msg, b)
 	}
 	}
 
 
-	return
+	return json.Marshal(msg)
 }
 }
 
 
 func newCertStateFromConfig(c *config.C) (*CertState, error) {
 func newCertStateFromConfig(c *config.C) (*CertState, error) {
@@ -198,24 +295,198 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
 		}
 		}
 	}
 	}
 
 
-	nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
+	var crt, v1, v2 cert.Certificate
+	for {
+		// Load the certificate
+		crt, rawCert, err = loadCertificate(rawCert)
+		if err != nil {
+			//TODO: check error
+			return nil, err
+		}
+
+		switch crt.Version() {
+		case cert.Version1:
+			if v1 != nil {
+				return nil, fmt.Errorf("v1 certificate already found in pki.cert")
+			}
+			v1 = crt
+		case cert.Version2:
+			if v2 != nil {
+				return nil, fmt.Errorf("v2 certificate already found in pki.cert")
+			}
+			v2 = crt
+		default:
+			return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
+		}
+
+		if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
+			break
+		}
+	}
+
+	if v1 == nil && v2 == nil {
+		return nil, errors.New("no certificates found in pki.cert")
+	}
+
+	useDefaultVersion := uint32(1)
+	if v1 == nil {
+		// The only condition that requires v2 as the default is if only a v2 certificate is present
+		// We do this to avoid having to configure it specifically in the config file
+		useDefaultVersion = 2
+	}
+
+	rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
+	var defaultVersion cert.Version
+	switch rawDefaultVersion {
+	case 1:
+		if v1 == nil {
+			return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
+		}
+		defaultVersion = cert.Version1
+	case 2:
+		defaultVersion = cert.Version2
+	default:
+		return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
+	}
+
+	return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
+}
+
+func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
+	cs := CertState{
+		privateKey:               privateKey,
+		pkcs11Backed:             pkcs11backed,
+		myVpnNetworksTable:       new(bart.Table[struct{}]),
+		myVpnAddrsTable:          new(bart.Table[struct{}]),
+		myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
+	}
+
+	if v1 != nil && v2 != nil {
+		if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
+			return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
+		}
+
+		if v1.Curve() != v2.Curve() {
+			return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
+		}
+
+		//TODO: make sure v2 has v1s address
+
+		cs.defaultVersion = dv
+	}
+
+	if v1 != nil {
+		if pkcs11backed {
+			//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
+		} else {
+			if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
+				return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+			}
+		}
+
+		v1hs, err := v1.MarshalForHandshakes()
+		if err != nil {
+			return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
+		}
+		cs.v1Cert = v1
+		cs.v1HandshakeBytes = v1hs
+
+		if cs.defaultVersion == 0 {
+			cs.defaultVersion = cert.Version1
+		}
+	}
+
+	if v2 != nil {
+		if pkcs11backed {
+			//TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm
+		} else {
+			if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
+				return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+			}
+		}
+
+		v2hs, err := v2.MarshalForHandshakes()
+		if err != nil {
+			return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
+		}
+		cs.v2Cert = v2
+		cs.v2HandshakeBytes = v2hs
+
+		if cs.defaultVersion == 0 {
+			cs.defaultVersion = cert.Version2
+		}
+	}
+
+	var crt cert.Certificate
+	crt = cs.getCertificate(cert.Version2)
+	if crt == nil {
+		// v2 certificates are a superset, only look at v1 if its all we have
+		crt = cs.getCertificate(cert.Version1)
+	}
+
+	for _, network := range crt.Networks() {
+		cs.myVpnNetworks = append(cs.myVpnNetworks, network)
+		cs.myVpnNetworksTable.Insert(network, struct{}{})
+
+		cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
+		cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
+
+		if network.Addr().Is4() {
+			addr := network.Masked().Addr().As4()
+			mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
+			binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
+			cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
+		}
+	}
+
+	return &cs, nil
+}
+
+func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
+	var pemPrivateKey []byte
+	if strings.Contains(privPathOrPEM, "-----BEGIN") {
+		pemPrivateKey = []byte(privPathOrPEM)
+		privPathOrPEM = "<inline>"
+		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+		}
+	} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
+		rawKey = []byte(privPathOrPEM)
+		return rawKey, cert.Curve_P256, true, nil
+	} else {
+		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
+		}
+		rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
+		if err != nil {
+			return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+		}
+	}
+
+	return
+}
+
+func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
+	c, b, err := cert.UnmarshalCertificateFromPEM(b)
 	if err != nil {
 	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
+		return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
 	}
 	}
 
 
-	if nebulaCert.Expired(time.Now()) {
-		return nil, fmt.Errorf("nebula certificate for this host is expired")
+	if c.Expired(time.Now()) {
+		return nil, b, fmt.Errorf("nebula certificate for this host is expired")
 	}
 	}
 
 
-	if len(nebulaCert.Networks()) == 0 {
-		return nil, fmt.Errorf("no networks encoded in certificate")
+	if len(c.Networks()) == 0 {
+		return nil, b, fmt.Errorf("no networks encoded in certificate")
 	}
 	}
 
 
-	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
-		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+	if c.IsCA() {
+		return nil, b, fmt.Errorf("host certificate is a CA certificate")
 	}
 	}
 
 
-	return newCertState(nebulaCert, isPkcs11, rawKey)
+	return c, b, nil
 }
 }
 
 
 func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
 func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {

+ 122 - 63
relay_manager.go

@@ -9,6 +9,7 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 )
 )
@@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
 				Type:       relayType,
 				Type:       relayType,
 				State:      state,
 				State:      state,
 				LocalIndex: index,
 				LocalIndex: index,
-				PeerIp:     vpnIp,
+				PeerAddr:   vpnIp,
 			}
 			}
 
 
 			if remoteIdx != nil {
 			if remoteIdx != nil {
@@ -91,40 +92,60 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti
 func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
 func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) {
 	relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
 	relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex)
 	if !ok {
 	if !ok {
-		rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp,
+		//TODO: we need to handle possibly logging deprecated fields as well
+		rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnAddrs[0],
 			"initiatorRelayIndex": m.InitiatorRelayIndex,
 			"initiatorRelayIndex": m.InitiatorRelayIndex,
-			"relayFrom":           m.RelayFromIp,
-			"relayTo":             m.RelayToIp}).Info("relayManager failed to update relay")
+			"relayFrom":           m.RelayFromAddr,
+			"relayTo":             m.RelayToAddr}).Info("relayManager failed to update relay")
 		return nil, fmt.Errorf("unknown relay")
 		return nil, fmt.Errorf("unknown relay")
 	}
 	}
 
 
 	return relay, nil
 	return relay, nil
 }
 }
 
 
-func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) {
+func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
+	msg := &NebulaControl{}
+	err := msg.Unmarshal(d)
+	if err != nil {
+		h.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
+		return
+	}
+
+	var v cert.Version
+	if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 {
+		v = cert.Version1
+
+		//TODO: yeah this is junk but maybe its less junky than the other options
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr)
+		msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
 
 
-	switch m.Type {
+		binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr)
+		msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b))
+	} else {
+		v = cert.Version2
+	}
+
+	switch msg.Type {
 	case NebulaControl_CreateRelayRequest:
 	case NebulaControl_CreateRelayRequest:
-		rm.handleCreateRelayRequest(h, f, m)
+		rm.handleCreateRelayRequest(v, h, f, msg)
 	case NebulaControl_CreateRelayResponse:
 	case NebulaControl_CreateRelayResponse:
-		rm.handleCreateRelayResponse(h, f, m)
+		rm.handleCreateRelayResponse(v, h, f, msg)
 	}
 	}
 
 
 }
 }
 
 
-func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
+func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
 	rm.l.WithFields(logrus.Fields{
 	rm.l.WithFields(logrus.Fields{
-		"relayFrom":           m.RelayFromIp,
-		"relayTo":             m.RelayToIp,
+		"relayFrom":           m.RelayFromAddr,
+		"relayTo":             m.RelayToAddr,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
-		"vpnIp":               h.vpnIp}).
+		"vpnAddrs":            h.vpnAddrs}).
 		Info("handleCreateRelayResponse")
 		Info("handleCreateRelayResponse")
-	target := m.RelayToIp
-	//TODO: IPV6-WORK
-	b := [4]byte{}
-	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
-	targetAddr := netip.AddrFrom4(b)
+
+	target := m.RelayToAddr
+	targetAddr := protoAddrToNetAddr(target)
 
 
 	relay, err := rm.EstablishRelay(h, m)
 	relay, err := rm.EstablishRelay(h, m)
 	if err != nil {
 	if err != nil {
@@ -136,68 +157,79 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		return
 		return
 	}
 	}
 	// I'm the middle man. Let the initiator know that the I've established the relay they requested.
 	// I'm the middle man. Let the initiator know that the I've established the relay they requested.
-	peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
+	peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr)
 	if peerHostInfo == nil {
 	if peerHostInfo == nil {
-		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
+		rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer")
 		return
 		return
 	}
 	}
 	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
 	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
 	if !ok {
 	if !ok {
-		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
+		rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo")
 		return
 		return
 	}
 	}
 	if peerRelay.State == PeerRequested {
 	if peerRelay.State == PeerRequested {
-		//TODO: IPV6-WORK
-		b = peerHostInfo.vpnIp.As4()
 		peerRelay.State = Established
 		peerRelay.State = Established
 		resp := NebulaControl{
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: peerRelay.LocalIndex,
 			ResponderRelayIndex: peerRelay.LocalIndex,
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
-			RelayFromIp:         binary.BigEndian.Uint32(b[:]),
-			RelayToIp:           uint32(target),
 		}
 		}
+
+		if v == cert.Version1 {
+			peer := peerHostInfo.vpnAddrs[0]
+			if !peer.Is4() {
+				//TODO: log cant do it
+				return
+			}
+
+			b := peer.As4()
+			resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = targetAddr.As4()
+			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		} else {
+			resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
+			resp.RelayToAddr = target
+		}
+
 		msg, err := resp.Marshal()
 		msg, err := resp.Marshal()
 		if err != nil {
 		if err != nil {
-			rm.l.
-				WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
+			rm.l.WithError(err).
+				Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
 		} else {
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           resp.RelayFromIp,
-				"relayTo":             resp.RelayToIp,
+				"relayFrom":           resp.RelayFromAddr,
+				"relayTo":             resp.RelayToAddr,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
-				"vpnIp":               peerHostInfo.vpnIp}).
+				"vpnAddrs":            peerHostInfo.vpnAddrs}).
 				Info("send CreateRelayResponse")
 				Info("send CreateRelayResponse")
 		}
 		}
 	}
 	}
 }
 }
 
 
-func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
-	//TODO: IPV6-WORK
-	b := [4]byte{}
-	binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
-	from := netip.AddrFrom4(b)
-
-	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
-	target := netip.AddrFrom4(b)
+func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
+	from := protoAddrToNetAddr(m.RelayFromAddr)
+	target := protoAddrToNetAddr(m.RelayToAddr)
 
 
 	logMsg := rm.l.WithFields(logrus.Fields{
 	logMsg := rm.l.WithFields(logrus.Fields{
 		"relayFrom":           from,
 		"relayFrom":           from,
 		"relayTo":             target,
 		"relayTo":             target,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
-		"vpnIp":               h.vpnIp})
+		"vpnAddrs":            h.vpnAddrs})
 
 
 	logMsg.Info("handleCreateRelayRequest")
 	logMsg.Info("handleCreateRelayRequest")
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// an issue migrating relays over to newly re-handshaked host info objects.
 	// an issue migrating relays over to newly re-handshaked host info objects.
-	if from == f.myVpnNet.Addr() {
+	_, found := f.myVpnAddrsTable.Lookup(from)
+	if found {
 		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		return
 		return
 	}
 	}
+
 	// Is the target of the relay me?
 	// Is the target of the relay me?
-	if target == f.myVpnNet.Addr() {
+	_, found = f.myVpnAddrsTable.Lookup(target)
+	if found {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		if ok {
 		if ok {
 			switch existingRelay.State {
 			switch existingRelay.State {
@@ -230,17 +262,22 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			return
 			return
 		}
 		}
 
 
-		//TODO: IPV6-WORK
-		fromB := from.As4()
-		targetB := target.As4()
-
 		resp := NebulaControl{
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: relay.LocalIndex,
 			ResponderRelayIndex: relay.LocalIndex,
 			InitiatorRelayIndex: relay.RemoteIndex,
 			InitiatorRelayIndex: relay.RemoteIndex,
-			RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
-			RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 		}
 		}
+
+		if v == cert.Version1 {
+			b := from.As4()
+			resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+			b = target.As4()
+			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+		} else {
+			resp.RelayFromAddr = netAddrToProtoAddr(from)
+			resp.RelayToAddr = netAddrToProtoAddr(target)
+		}
+
 		msg, err := resp.Marshal()
 		msg, err := resp.Marshal()
 		if err != nil {
 		if err != nil {
 			logMsg.
 			logMsg.
@@ -253,7 +290,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 				"relayTo":             target,
 				"relayTo":             target,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
-				"vpnIp":               h.vpnIp}).
+				"vpnAddrs":            h.vpnAddrs}).
 				Info("send CreateRelayResponse")
 				Info("send CreateRelayResponse")
 		}
 		}
 		return
 		return
@@ -262,7 +299,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		if !rm.GetAmRelay() {
 		if !rm.GetAmRelay() {
 			return
 			return
 		}
 		}
-		peer := rm.hostmap.QueryVpnIp(target)
+		peer := rm.hostmap.QueryVpnAddr(target)
 		if peer == nil {
 		if peer == nil {
 			// Try to establish a connection to this host. If we get a future relay request,
 			// Try to establish a connection to this host. If we get a future relay request,
 			// we'll be ready!
 			// we'll be ready!
@@ -291,17 +328,27 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			sendCreateRequest = true
 			sendCreateRequest = true
 		}
 		}
 		if sendCreateRequest {
 		if sendCreateRequest {
-			//TODO: IPV6-WORK
-			fromB := h.vpnIp.As4()
-			targetB := target.As4()
-
 			// Send a CreateRelayRequest to the peer.
 			// Send a CreateRelayRequest to the peer.
 			req := NebulaControl{
 			req := NebulaControl{
 				Type:                NebulaControl_CreateRelayRequest,
 				Type:                NebulaControl_CreateRelayRequest,
 				InitiatorRelayIndex: index,
 				InitiatorRelayIndex: index,
-				RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
-				RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 			}
 			}
+
+			if v == cert.Version1 {
+				if !h.vpnAddrs[0].Is4() {
+					//TODO: log it
+					return
+				}
+
+				b := h.vpnAddrs[0].As4()
+				req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+				b = target.As4()
+				req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+			} else {
+				req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
+				req.RelayToAddr = netAddrToProtoAddr(target)
+			}
+
 			msg, err := req.Marshal()
 			msg, err := req.Marshal()
 			if err != nil {
 			if err != nil {
 				logMsg.
 				logMsg.
@@ -310,11 +357,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				rm.l.WithFields(logrus.Fields{
 				rm.l.WithFields(logrus.Fields{
 					//TODO: IPV6-WORK another lazy used to use the req object
 					//TODO: IPV6-WORK another lazy used to use the req object
-					"relayFrom":           h.vpnIp,
+					"relayFrom":           h.vpnAddrs[0],
 					"relayTo":             target,
 					"relayTo":             target,
 					"initiatorRelayIndex": req.InitiatorRelayIndex,
 					"initiatorRelayIndex": req.InitiatorRelayIndex,
 					"responderRelayIndex": req.ResponderRelayIndex,
 					"responderRelayIndex": req.ResponderRelayIndex,
-					"vpnIp":               target}).
+					"vpnAddr":             target}).
 					Info("send CreateRelayRequest")
 					Info("send CreateRelayRequest")
 			}
 			}
 		}
 		}
@@ -342,16 +389,28 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 						"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
 						"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
 					return
 					return
 				}
 				}
-				//TODO: IPV6-WORK
-				fromB := h.vpnIp.As4()
-				targetB := target.As4()
+
 				resp := NebulaControl{
 				resp := NebulaControl{
 					Type:                NebulaControl_CreateRelayResponse,
 					Type:                NebulaControl_CreateRelayResponse,
 					ResponderRelayIndex: relay.LocalIndex,
 					ResponderRelayIndex: relay.LocalIndex,
 					InitiatorRelayIndex: relay.RemoteIndex,
 					InitiatorRelayIndex: relay.RemoteIndex,
-					RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
-					RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 				}
 				}
+
+				if v == cert.Version1 {
+					if !h.vpnAddrs[0].Is4() {
+						//TODO: log it
+						return
+					}
+
+					b := h.vpnAddrs[0].As4()
+					resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+					b = target.As4()
+					resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+				} else {
+					resp.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
+					resp.RelayToAddr = netAddrToProtoAddr(target)
+				}
+
 				msg, err := resp.Marshal()
 				msg, err := resp.Marshal()
 				if err != nil {
 				if err != nil {
 					rm.l.
 					rm.l.
@@ -360,11 +419,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					rm.l.WithFields(logrus.Fields{
 					rm.l.WithFields(logrus.Fields{
 						//TODO: IPV6-WORK more lazy, used to use resp object
 						//TODO: IPV6-WORK more lazy, used to use resp object
-						"relayFrom":           h.vpnIp,
+						"relayFrom":           h.vpnAddrs[0],
 						"relayTo":             target,
 						"relayTo":             target,
 						"initiatorRelayIndex": resp.InitiatorRelayIndex,
 						"initiatorRelayIndex": resp.InitiatorRelayIndex,
 						"responderRelayIndex": resp.ResponderRelayIndex,
 						"responderRelayIndex": resp.ResponderRelayIndex,
-						"vpnIp":               h.vpnIp}).
+						"vpnAddrs":            h.vpnAddrs}).
 						Info("send CreateRelayResponse")
 						Info("send CreateRelayResponse")
 				}
 				}
 
 

+ 34 - 28
remote_list.go

@@ -17,8 +17,8 @@ import (
 type forEachFunc func(addr netip.AddrPort, preferred bool)
 type forEachFunc func(addr netip.AddrPort, preferred bool)
 
 
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool
+type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool
 
 
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // The string key is the owners vpnIp
 // The string key is the owners vpnIp
@@ -48,14 +48,14 @@ type cacheRelay struct {
 
 
 // cacheV4 stores learned and reported ipv4 records under cache
 // cacheV4 stores learned and reported ipv4 records under cache
 type cacheV4 struct {
 type cacheV4 struct {
-	learned  *Ip4AndPort
-	reported []*Ip4AndPort
+	learned  *V4AddrPort
+	reported []*V4AddrPort
 }
 }
 
 
 // cacheV4 stores learned and reported ipv6 records under cache
 // cacheV4 stores learned and reported ipv6 records under cache
 type cacheV6 struct {
 type cacheV6 struct {
-	learned  *Ip6AndPort
-	reported []*Ip6AndPort
+	learned  *V6AddrPort
+	reported []*V6AddrPort
 }
 }
 
 
 type hostnamePort struct {
 type hostnamePort struct {
@@ -170,7 +170,7 @@ func (hr *hostnamesResults) Cancel() {
 	}
 	}
 }
 }
 
 
-func (hr *hostnamesResults) GetIPs() []netip.AddrPort {
+func (hr *hostnamesResults) GetAddrs() []netip.AddrPort {
 	var retSlice []netip.AddrPort
 	var retSlice []netip.AddrPort
 	if hr != nil {
 	if hr != nil {
 		p := hr.ips.Load()
 		p := hr.ips.Load()
@@ -189,6 +189,9 @@ type RemoteList struct {
 	// Every interaction with internals requires a lock!
 	// Every interaction with internals requires a lock!
 	sync.RWMutex
 	sync.RWMutex
 
 
+	// The full list of vpn addresses assigned to this host
+	vpnAddrs []netip.Addr
+
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
 	addrs []netip.AddrPort
 	addrs []netip.AddrPort
 
 
@@ -212,13 +215,16 @@ type RemoteList struct {
 }
 }
 
 
 // NewRemoteList creates a new empty RemoteList
 // NewRemoteList creates a new empty RemoteList
-func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
-	return &RemoteList{
+func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList {
+	r := &RemoteList{
+		vpnAddrs:  make([]netip.Addr, len(vpnAddrs)),
 		addrs:     make([]netip.AddrPort, 0),
 		addrs:     make([]netip.AddrPort, 0),
 		relays:    make([]netip.Addr, 0),
 		relays:    make([]netip.Addr, 0),
 		cache:     make(map[netip.Addr]*cache),
 		cache:     make(map[netip.Addr]*cache),
 		shouldAdd: shouldAdd,
 		shouldAdd: shouldAdd,
 	}
 	}
+	copy(r.vpnAddrs, vpnAddrs)
+	return r
 }
 }
 
 
 func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
 func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
@@ -273,9 +279,9 @@ func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
 	r.Lock()
 	r.Lock()
 	defer r.Unlock()
 	defer r.Unlock()
 	if remote.Addr().Is4() {
 	if remote.Addr().Is4() {
-		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
+		r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port()))
 	} else {
 	} else {
-		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
+		r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port()))
 	}
 	}
 }
 }
 
 
@@ -304,21 +310,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
 
 
 		if mc.v4 != nil {
 		if mc.v4 != nil {
 			if mc.v4.learned != nil {
 			if mc.v4.learned != nil {
-				c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
+				c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned))
 			}
 			}
 
 
 			for _, a := range mc.v4.reported {
 			for _, a := range mc.v4.reported {
-				c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
+				c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a))
 			}
 			}
 		}
 		}
 
 
 		if mc.v6 != nil {
 		if mc.v6 != nil {
 			if mc.v6.learned != nil {
 			if mc.v6.learned != nil {
-				c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
+				c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned))
 			}
 			}
 
 
 			for _, a := range mc.v6.reported {
 			for _, a := range mc.v6.reported {
-				c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
+				c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a))
 			}
 			}
 		}
 		}
 
 
@@ -401,14 +407,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
 
 
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 }
 }
 
 
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 
@@ -436,12 +442,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A
 
 
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 
 	// We are doing the easy append because this is rarely called
 	// We are doing the easy append because this is rarely called
-	c.reported = append([]*Ip4AndPort{to}, c.reported...)
+	c.reported = append([]*V4AddrPort{to}, c.reported...)
 	if len(c.reported) > MaxRemotes {
 	if len(c.reported) > MaxRemotes {
 		c.reported = c.reported[:MaxRemotes]
 		c.reported = c.reported[:MaxRemotes]
 	}
 	}
@@ -449,14 +455,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
 
 
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 }
 }
 
 
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 
@@ -473,12 +479,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor
 
 
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 
 	// We are doing the easy append because this is rarely called
 	// We are doing the easy append because this is rarely called
-	c.reported = append([]*Ip6AndPort{to}, c.reported...)
+	c.reported = append([]*V6AddrPort{to}, c.reported...)
 	if len(c.reported) > MaxRemotes {
 	if len(c.reported) > MaxRemotes {
 		c.reported = c.reported[:MaxRemotes]
 		c.reported = c.reported[:MaxRemotes]
 	}
 	}
@@ -536,14 +542,14 @@ func (r *RemoteList) unlockedCollect() {
 	for _, c := range r.cache {
 	for _, c := range r.cache {
 		if c.v4 != nil {
 		if c.v4 != nil {
 			if c.v4.learned != nil {
 			if c.v4.learned != nil {
-				u := AddrPortFromIp4AndPort(c.v4.learned)
+				u := protoV4AddrPortToNetAddrPort(c.v4.learned)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
 			}
 			}
 
 
 			for _, v := range c.v4.reported {
 			for _, v := range c.v4.reported {
-				u := AddrPortFromIp4AndPort(v)
+				u := protoV4AddrPortToNetAddrPort(v)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
@@ -552,14 +558,14 @@ func (r *RemoteList) unlockedCollect() {
 
 
 		if c.v6 != nil {
 		if c.v6 != nil {
 			if c.v6.learned != nil {
 			if c.v6.learned != nil {
-				u := AddrPortFromIp6AndPort(c.v6.learned)
+				u := protoV6AddrPortToNetAddrPort(c.v6.learned)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
 			}
 			}
 
 
 			for _, v := range c.v6.reported {
 			for _, v := range c.v6.reported {
-				u := AddrPortFromIp6AndPort(v)
+				u := protoV6AddrPortToNetAddrPort(v)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
@@ -573,7 +579,7 @@ func (r *RemoteList) unlockedCollect() {
 		}
 		}
 	}
 	}
 
 
-	dnsAddrs := r.hr.GetIPs()
+	dnsAddrs := r.hr.GetAddrs()
 	for _, addr := range dnsAddrs {
 	for _, addr := range dnsAddrs {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
 			if !r.unlockedIsBad(addr) {
 			if !r.unlockedIsBad(addr) {

+ 20 - 20
remote_list_test.go

@@ -9,11 +9,11 @@ import (
 )
 )
 
 
 func TestRemoteList_Rebuild(t *testing.T) {
 func TestRemoteList_Rebuild(t *testing.T) {
-	rl := NewRemoteList(nil)
+	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip4AndPort{
+		[]*V4AddrPort{
 			newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
 			newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
 			newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
@@ -25,20 +25,20 @@ func TestRemoteList_Rebuild(t *testing.T) {
 			newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
 			newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
 			newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
 			newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
 		netip.MustParseAddr("0.0.0.1"),
 		netip.MustParseAddr("0.0.0.1"),
 		netip.MustParseAddr("0.0.0.1"),
 		netip.MustParseAddr("0.0.0.1"),
-		[]*Ip6AndPort{
+		[]*V6AddrPort{
 			newIp6AndPortFromString("[1::1]:1"), // this is duped
 			newIp6AndPortFromString("[1::1]:1"), // this is duped
 			newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
 			newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 			newIp6AndPortFromString("[1::1]:2"), // this is a dupe
 			newIp6AndPortFromString("[1::1]:2"), // this is a dupe
 		},
 		},
-		func(netip.Addr, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *V6AddrPort) bool { return true },
 	)
 	)
 
 
 	rl.Rebuild([]netip.Prefix{})
 	rl.Rebuild([]netip.Prefix{})
@@ -98,11 +98,11 @@ func TestRemoteList_Rebuild(t *testing.T) {
 }
 }
 
 
 func BenchmarkFullRebuild(b *testing.B) {
 func BenchmarkFullRebuild(b *testing.B) {
-	rl := NewRemoteList(nil)
+	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip4AndPort{
+		[]*V4AddrPort{
 			newIp4AndPortFromString("70.199.182.92:1475"),
 			newIp4AndPortFromString("70.199.182.92:1475"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"),
@@ -112,19 +112,19 @@ func BenchmarkFullRebuild(b *testing.B) {
 			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
 			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
 			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip6AndPort{
+		[]*V6AddrPort{
 			newIp6AndPortFromString("[1::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"),
 			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
 			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
 		},
-		func(netip.Addr, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *V6AddrPort) bool { return true },
 	)
 	)
 
 
 	b.Run("no preferred", func(b *testing.B) {
 	b.Run("no preferred", func(b *testing.B) {
@@ -160,11 +160,11 @@ func BenchmarkFullRebuild(b *testing.B) {
 }
 }
 
 
 func BenchmarkSortRebuild(b *testing.B) {
 func BenchmarkSortRebuild(b *testing.B) {
-	rl := NewRemoteList(nil)
+	rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip4AndPort{
+		[]*V4AddrPort{
 			newIp4AndPortFromString("70.199.182.92:1475"),
 			newIp4AndPortFromString("70.199.182.92:1475"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.0.182:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"),
 			newIp4AndPortFromString("172.17.1.1:10101"),
@@ -174,19 +174,19 @@ func BenchmarkSortRebuild(b *testing.B) {
 			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
 			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
 			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
 		},
-		func(netip.Addr, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *V4AddrPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
 		netip.MustParseAddr("0.0.0.0"),
-		[]*Ip6AndPort{
+		[]*V6AddrPort{
 			newIp6AndPortFromString("[1::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"),
 			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
 			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1:100::1]:1"),
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
 		},
-		func(netip.Addr, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *V6AddrPort) bool { return true },
 	)
 	)
 
 
 	b.Run("no preferred", func(b *testing.B) {
 	b.Run("no preferred", func(b *testing.B) {
@@ -224,19 +224,19 @@ func BenchmarkSortRebuild(b *testing.B) {
 	})
 	})
 }
 }
 
 
-func newIp4AndPortFromString(s string) *Ip4AndPort {
+func newIp4AndPortFromString(s string) *V4AddrPort {
 	a := netip.MustParseAddrPort(s)
 	a := netip.MustParseAddrPort(s)
 	v4Addr := a.Addr().As4()
 	v4Addr := a.Addr().As4()
-	return &Ip4AndPort{
-		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+	return &V4AddrPort{
+		Addr: binary.BigEndian.Uint32(v4Addr[:]),
 		Port: uint32(a.Port()),
 		Port: uint32(a.Port()),
 	}
 	}
 }
 }
 
 
-func newIp6AndPortFromString(s string) *Ip6AndPort {
+func newIp6AndPortFromString(s string) *V6AddrPort {
 	a := netip.MustParseAddrPort(s)
 	a := netip.MustParseAddrPort(s)
 	v6Addr := a.Addr().As16()
 	v6Addr := a.Addr().As16()
-	return &Ip6AndPort{
+	return &V6AddrPort{
 		Hi:   binary.BigEndian.Uint64(v6Addr[:8]),
 		Hi:   binary.BigEndian.Uint64(v6Addr[:8]),
 		Lo:   binary.BigEndian.Uint64(v6Addr[8:]),
 		Lo:   binary.BigEndian.Uint64(v6Addr[8:]),
 		Port: uint32(a.Port()),
 		Port: uint32(a.Port()),

+ 2 - 2
service/service.go

@@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) {
 		},
 		},
 	})
 	})
 
 
-	ipNet := device.Cidr()
+	ipNet := device.Networks()
 	pa := tcpip.ProtocolAddress{
 	pa := tcpip.ProtocolAddress{
-		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
+		AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(),
 		Protocol:          ipv4.ProtocolNumber,
 		Protocol:          ipv4.ProtocolNumber,
 	}
 	}
 	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
 	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

+ 1 - 1
service/service_test.go

@@ -19,7 +19,7 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
 func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
-	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
+	_, _, myPrivKey, myPEM := e2e.NewTestCert(cert.Version2, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{})
 	caB, err := caCrt.MarshalPEM()
 	caB, err := caCrt.MarshalPEM()
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)

+ 20 - 17
ssh.go

@@ -430,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 	}
 	}
 
 
 	sort.Slice(hm, func(i, j int) bool {
 	sort.Slice(hm, func(i, j int) bool {
-		return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
+		return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0
 	})
 	})
 
 
 	if fs.Json || fs.Pretty {
 	if fs.Json || fs.Pretty {
@@ -447,7 +447,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 
 
 	} else {
 	} else {
 		for _, v := range hm {
 		for _, v := range hm {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
+			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs))
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
@@ -581,7 +581,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 	}
@@ -622,12 +622,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo != nil {
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 	}
 
 
-	hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
+	hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp)
 	if hostInfo != nil {
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
 	}
@@ -677,7 +677,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 	}
@@ -785,7 +785,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 		return nil
 		return nil
 	}
 	}
 
 
-	cert := ifce.pki.GetCertState().Certificate
+	//TODO: This should return both certs
+	cert := ifce.pki.getDefaultCertificate()
 	if len(a) > 0 {
 	if len(a) > 0 {
 		vpnIp, err := netip.ParseAddr(a[0])
 		vpnIp, err := netip.ParseAddr(a[0])
 		if err != nil {
 		if err != nil {
@@ -796,7 +797,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 		}
 
 
-		hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+		hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
 		if hostInfo == nil {
 		if hostInfo == nil {
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		}
 		}
@@ -880,16 +881,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	}
 	}
 
 
 	for k, v := range relays {
 	for k, v := range relays {
-		ro := RelayOutput{NebulaIp: v.vpnIp}
+		ro := RelayOutput{NebulaIp: v.vpnAddrs[0]}
 		co.Relays = append(co.Relays, &ro)
 		co.Relays = append(co.Relays, &ro)
-		relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
+		relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0])
 		if relayHI == nil {
 		if relayHI == nil {
 			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
 			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
 			continue
 			continue
 		}
 		}
 		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
 		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
 			rf := RelayFor{Error: nil}
 			rf := RelayFor{Error: nil}
-			r, ok := relayHI.relayState.GetRelayForByIp(vpnIp)
+			r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp)
 			if ok {
 			if ok {
 				t := ""
 				t := ""
 				switch r.Type {
 				switch r.Type {
@@ -913,14 +914,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 
 
 				rf.LocalIndex = r.LocalIndex
 				rf.LocalIndex = r.LocalIndex
 				rf.RemoteIndex = r.RemoteIndex
 				rf.RemoteIndex = r.RemoteIndex
-				rf.PeerIp = r.PeerIp
+				rf.PeerIp = r.PeerAddr
 				rf.Type = t
 				rf.Type = t
 				rf.State = s
 				rf.State = s
 				if rf.LocalIndex != k {
 				if rf.LocalIndex != k {
 					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
 					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
 				}
 				}
 			}
 			}
-			relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
+			relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp)
 			if relayedHI != nil {
 			if relayedHI != nil {
 				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
 				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
 			}
 			}
@@ -955,7 +956,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 	}
@@ -971,13 +972,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
 func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
 
 
 	data := struct {
 	data := struct {
-		Name string `json:"name"`
-		Cidr string `json:"cidr"`
+		Name string         `json:"name"`
+		Cidr []netip.Prefix `json:"cidr"`
 	}{
 	}{
 		Name: ifce.inside.Name(),
 		Name: ifce.inside.Name(),
-		Cidr: ifce.inside.Cidr().String(),
+		Cidr: make([]netip.Prefix, len(ifce.inside.Networks())),
 	}
 	}
 
 
+	copy(data.Cidr, ifce.inside.Networks())
+
 	flags, ok := fs.(*sshDeviceInfoFlags)
 	flags, ok := fs.(*sshDeviceInfoFlags)
 	if !ok {
 	if !ok {
 		return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)
 		return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)

+ 2 - 2
test/tun.go

@@ -16,8 +16,8 @@ func (NoopTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
-func (NoopTun) Cidr() netip.Prefix {
-	return netip.Prefix{}
+func (NoopTun) Networks() []netip.Prefix {
+	return []netip.Prefix{}
 }
 }
 
 
 func (NoopTun) Name() string {
 func (NoopTun) Name() string {

+ 4 - 4
timeout_test.go

@@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)
 
 
 	fps := []firewall.Packet{
 	fps := []firewall.Packet{
-		{LocalIP: netip.MustParseAddr("0.0.0.1")},
-		{LocalIP: netip.MustParseAddr("0.0.0.2")},
-		{LocalIP: netip.MustParseAddr("0.0.0.3")},
-		{LocalIP: netip.MustParseAddr("0.0.0.4")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.1")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.2")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.3")},
+		{LocalAddr: netip.MustParseAddr("0.0.0.4")},
 	}
 	}
 
 
 	tw.Add(fps[0], time.Second*1)
 	tw.Add(fps[0], time.Second*1)

+ 3 - 12
udp/conn.go

@@ -4,28 +4,19 @@ import (
 	"net/netip"
 	"net/netip"
 
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/header"
 )
 )
 
 
 const MTU = 9001
 const MTU = 9001
 
 
 type EncReader func(
 type EncReader func(
 	addr netip.AddrPort,
 	addr netip.AddrPort,
-	out []byte,
-	packet []byte,
-	header *header.H,
-	fwPacket *firewall.Packet,
-	lhh LightHouseHandlerFunc,
-	nb []byte,
-	q int,
-	localCache firewall.ConntrackCache,
+	payload []byte,
 )
 )
 
 
 type Conn interface {
 type Conn interface {
 	Rebind() error
 	Rebind() error
 	LocalAddr() (netip.AddrPort, error)
 	LocalAddr() (netip.AddrPort, error)
-	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
+	ListenOut(r EncReader)
 	WriteTo(b []byte, addr netip.AddrPort) error
 	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
 	ReloadConfig(c *config.C)
 	Close() error
 	Close() error
@@ -39,7 +30,7 @@ func (NoopConn) Rebind() error {
 func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 func (NoopConn) LocalAddr() (netip.AddrPort, error) {
 	return netip.AddrPort{}, nil
 	return netip.AddrPort{}, nil
 }
 }
-func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
+func (NoopConn) ListenOut(_ EncReader) {
 	return
 	return
 }
 }
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {

+ 0 - 10
udp/temp.go

@@ -1,10 +0,0 @@
-package udp
-
-import (
-	"net/netip"
-)
-
-//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
-
-// TODO: IPV6-WORK this can likely be removed now
-type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

+ 2 - 18
udp/udp_generic.go

@@ -15,8 +15,6 @@ import (
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/header"
 )
 )
 
 
 type GenericConn struct {
 type GenericConn struct {
@@ -72,12 +70,8 @@ type rawMessage struct {
 	Len uint32
 	Len uint32
 }
 }
 
 
-func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
-	plaintext := make([]byte, MTU)
+func (u *GenericConn) ListenOut(r EncReader) {
 	buffer := make([]byte, MTU)
 	buffer := make([]byte, MTU)
-	h := &header.H{}
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
 
 
 	for {
 	for {
 		// Just read one packet at a time
 		// Just read one packet at a time
@@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
 			return
 			return
 		}
 		}
 
 
-		r(
-			netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
-			plaintext[:0],
-			buffer[:n],
-			h,
-			fwPacket,
-			lhf,
-			nb,
-			q,
-			cache.Get(u.l),
-		)
+		r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
 	}
 	}
 }
 }

+ 3 - 23
udp/udp_linux.go

@@ -14,8 +14,6 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/header"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
@@ -120,15 +118,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 	}
 	}
 }
 }
 
 
-func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
-	plaintext := make([]byte, MTU)
-	h := &header.H{}
-	fwPacket := &firewall.Packet{}
+func (u *StdConn) ListenOut(r EncReader) {
 	var ip netip.Addr
 	var ip netip.Addr
-	nb := make([]byte, 12, 12)
 
 
-	//TODO: should we track this?
-	//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
 	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	msgs, buffers, names := u.PrepareRawMessages(u.batch)
 	read := u.ReadMulti
 	read := u.ReadMulti
 	if u.batch == 1 {
 	if u.batch == 1 {
@@ -142,26 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 			return
 			return
 		}
 		}
 
 
-		//metric.Update(int64(n))
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
+			// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
 			if u.isV4 {
 			if u.isV4 {
 				ip, _ = netip.AddrFromSlice(names[i][4:8])
 				ip, _ = netip.AddrFromSlice(names[i][4:8])
-				//TODO: IPV6-WORK what is not ok?
 			} else {
 			} else {
 				ip, _ = netip.AddrFromSlice(names[i][8:24])
 				ip, _ = netip.AddrFromSlice(names[i][8:24])
-				//TODO: IPV6-WORK what is not ok?
 			}
 			}
-			r(
-				netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
-				plaintext[:0],
-				buffers[i][:msgs[i].Len],
-				h,
-				fwPacket,
-				lhf,
-				nb,
-				q,
-				cache.Get(u.l),
-			)
+			r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len])
 		}
 		}
 	}
 	}
 }
 }

+ 2 - 19
udp/udp_rio_windows.go

@@ -18,9 +18,6 @@ import (
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/header"
-
 	"golang.org/x/sys/windows"
 	"golang.org/x/sys/windows"
 	"golang.zx2c4.com/wireguard/conn/winrio"
 	"golang.zx2c4.com/wireguard/conn/winrio"
 )
 )
@@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error {
 	return nil
 	return nil
 }
 }
 
 
-func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
-	plaintext := make([]byte, MTU)
+func (u *RIOConn) ListenOut(r EncReader) {
 	buffer := make([]byte, MTU)
 	buffer := make([]byte, MTU)
-	h := &header.H{}
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
 
 
 	for {
 	for {
 		// Just read one packet at a time
 		// Just read one packet at a time
@@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 			return
 			return
 		}
 		}
 
 
-		r(
-			netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
-			plaintext[:0],
-			buffer[:n],
-			h,
-			fwPacket,
-			lhf,
-			nb,
-			q,
-			cache.Get(u.l),
-		)
+		r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n])
 	}
 	}
 }
 }
 
 

+ 2 - 8
udp/udp_tester.go

@@ -10,7 +10,6 @@ import (
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 )
 )
 
 
@@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	return nil
 	return nil
 }
 }
 
 
-func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
-	plaintext := make([]byte, MTU)
-	h := &header.H{}
-	fwPacket := &firewall.Packet{}
-	nb := make([]byte, 12, 12)
-
+func (u *TesterConn) ListenOut(r EncReader) {
 	for {
 	for {
 		p, ok := <-u.RxPackets
 		p, ok := <-u.RxPackets
 		if !ok {
 		if !ok {
 			return
 			return
 		}
 		}
-		r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(p.From, p.Data)
 	}
 	}
 }
 }
 
 

Some files were not shown because too many files changed in this diff