Przeglądaj źródła

Switch most everything to netip in prep for ipv6 in the overlay (#1173)

Nate Brown 1 rok temu
rodzic
commit
e264a0ff88
79 zmienionych plików z 1896 dodań i 2678 usunięć
  1. 26 65
      allow_list.go
  2. 21 21
      allow_list_test.go
  3. 41 25
      calculated_remote.go
  4. 7 9
      calculated_remote_test.go
  5. 0 10
      cidr/parse.go
  6. 0 203
      cidr/tree4.go
  7. 0 170
      cidr/tree4_test.go
  8. 0 189
      cidr/tree6.go
  9. 0 98
      cidr/tree6_test.go
  10. 17 13
      connection_manager.go
  11. 17 17
      connection_manager_test.go
  12. 19 21
      control.go
  13. 33 24
      control_test.go
  14. 20 27
      control_tester.go
  15. 12 6
      dns_server.go
  16. 169 169
      e2e/handshakes_test.go
  17. 15 8
      e2e/helpers.go
  18. 28 24
      e2e/helpers_test.go
  19. 4 4
      e2e/router/hostmap.go
  20. 36 63
      e2e/router/router.go
  21. 58 42
      firewall.go
  22. 3 4
      firewall/packet.go
  23. 74 73
      firewall_test.go
  24. 2 0
      go.mod
  25. 6 0
      go.sum
  26. 42 16
      handshake_ix.go
  27. 50 41
      handshake_manager.go
  28. 9 9
      handshake_manager_test.go
  29. 75 71
      hostmap.go
  30. 25 34
      hostmap_test.go
  31. 4 2
      hostmap_tester.go
  32. 20 24
      inside.go
  33. 36 11
      interface.go
  34. 2 0
      iputil/packet.go
  35. 0 93
      iputil/util.go
  36. 0 17
      iputil/util_test.go
  37. 212 188
      lighthouse.go
  38. 85 102
      lighthouse_test.go
  39. 22 8
      main.go
  40. 48 47
      outside.go
  41. 5 5
      outside_test.go
  42. 3 5
      overlay/device.go
  43. 19 25
      overlay/route.go
  44. 27 16
      overlay/route_test.go
  45. 5 5
      overlay/tun.go
  46. 9 10
      overlay/tun_android.go
  47. 41 18
      overlay/tun_darwin.go
  48. 6 6
      overlay/tun_disabled.go
  49. 11 12
      overlay/tun_freebsd.go
  50. 9 10
      overlay/tun_ios.go
  51. 56 35
      overlay/tun_linux.go
  52. 14 15
      overlay/tun_netbsd.go
  53. 14 15
      overlay/tun_openbsd.go
  54. 9 10
      overlay/tun_tester.go
  55. 11 11
      overlay/tun_water_windows.go
  56. 3 3
      overlay/tun_windows.go
  57. 12 38
      overlay/tun_wintun_windows.go
  58. 7 8
      overlay/user.go
  59. 2 0
      pki.go
  60. 52 31
      relay_manager.go
  61. 73 93
      remote_list.go
  62. 98 89
      remote_list_test.go
  63. 1 1
      service/service.go
  64. 6 10
      service/service_test.go
  65. 29 36
      ssh.go
  66. 5 7
      test/tun.go
  67. 5 4
      timeout_test.go
  68. 8 6
      udp/conn.go
  69. 3 2
      udp/temp.go
  70. 0 100
      udp/udp_all.go
  71. 2 1
      udp/udp_android.go
  72. 2 1
      udp/udp_bsd.go
  73. 2 1
      udp/udp_darwin.go
  74. 23 14
      udp/udp_generic.go
  75. 43 32
      udp/udp_linux.go
  76. 2 1
      udp/udp_netbsd.go
  77. 22 21
      udp/udp_rio_windows.go
  78. 17 32
      udp/udp_tester.go
  79. 2 1
      udp/udp_windows.go

+ 26 - 65
allow_list.go

@@ -2,17 +2,16 @@ package nebula
 
 import (
 	"fmt"
-	"net"
+	"net/netip"
 	"regexp"
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *cidr.Tree6[bool]
+	cidrTree *bart.Table[bool]
 }
 
 type RemoteAllowList struct {
@@ -20,7 +19,7 @@ type RemoteAllowList struct {
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
-	insideAllowLists *cidr.Tree6[*AllowList]
+	insideAllowLists *bart.Table[*AllowList]
 }
 
 type LocalAllowList struct {
@@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
 
-	tree := cidr.NewTree6[bool]()
+	tree := new(bart.Table[bool])
 
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	type allowListRules struct {
@@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
 		}
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		ipNet, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
 		}
 
+		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
+
 		// TODO: should we error on duplicate CIDRs in the config?
-		tree.AddCIDR(ipNet, value)
+		tree.Insert(ipNet, value)
 
-		maskBits, maskSize := ipNet.Mask.Size()
+		maskBits := ipNet.Bits()
 
 		var rules *allowListRules
-		if maskSize == 32 {
+		if ipNet.Addr().Is4() {
 			rules = &rules4
 		} else {
 			rules = &rules6
@@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 	if !rules4.defaultSet {
 		if rules4.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
-			tree.AddCIDR(zeroCIDR, !rules4.allValues)
+			tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
 		} else {
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
 		}
@@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 	if !rules6.defaultSet {
 		if rules6.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("::/0")
-			tree.AddCIDR(zeroCIDR, !rules6.allValues)
+			tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
 		} else {
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
 		}
@@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
 	return nameRules, nil
 }
 
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
+func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	remoteAllowRanges := cidr.NewTree6[*AllowList]()
+	remoteAllowRanges := new(bart.Table[*AllowList])
 
 	rawMap, ok := value.(map[interface{}]interface{})
 	if !ok {
@@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error
 			return nil, err
 		}
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		ipNet, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
 		}
 
-		remoteAllowRanges.AddCIDR(ipNet, allowList)
+		remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
 	}
 
 	return remoteAllowRanges, nil
 }
 
-func (al *AllowList) Allow(ip net.IP) bool {
-	if al == nil {
-		return true
-	}
-
-	_, result := al.cidrTree.MostSpecificContains(ip)
-	return result
-}
-
-func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
-	if al == nil {
-		return true
-	}
-
-	_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
-	return result
-}
-
-func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
+func (al *AllowList) Allow(ip netip.Addr) bool {
 	if al == nil {
 		return true
 	}
 
-	_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+	result, _ := al.cidrTree.Lookup(ip)
 	return result
 }
 
-func (al *LocalAllowList) Allow(ip net.IP) bool {
+func (al *LocalAllowList) Allow(ip netip.Addr) bool {
 	if al == nil {
 		return true
 	}
@@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 }
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
+func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
 	if al == nil {
 		return true
 	}
 	return al.AllowList.Allow(ip)
 }
 
-func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
+func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
 	if !al.getInsideAllowList(vpnIp).Allow(ip) {
 		return false
 	}
 	return al.AllowList.Allow(ip)
 }
 
-func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
-	if al == nil {
-		return true
-	}
-	if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
-		return false
-	}
-	return al.AllowList.AllowIpV4(ip)
-}
-
-func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
-	if al == nil {
-		return true
-	}
-	if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
-		return false
-	}
-	return al.AllowList.AllowIpV6(hi, lo)
-}
-
-func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
-		ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnIp)
 		if ok {
 			return inside
 		}

+ 21 - 21
allow_list_test.go

@@ -1,11 +1,11 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"regexp"
 	"testing"
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
@@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		"192.168.0.0": true,
 	}
 	r, err := newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
 	assert.Nil(t, r)
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
@@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) {
 }
 
 func TestAllowList_Allow(t *testing.T) {
-	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
-
-	tree := cidr.NewTree6[bool]()
-	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
-	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
-	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
-	tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
-	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
-	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
-	tree.AddCIDR(cidr.Parse("::1/128"), true)
-	tree.AddCIDR(cidr.Parse("::2/128"), false)
+	assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
+
+	tree := new(bart.Table[bool])
+	tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
+	tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
+	tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
+	tree.Insert(netip.MustParsePrefix("::1/128"), true)
+	tree.Insert(netip.MustParsePrefix("::2/128"), false)
 	al := &AllowList{cidrTree: tree}
 
-	assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
+	assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
+	assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
+	assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
 }
 
 func TestLocalAllowList_AllowName(t *testing.T) {

+ 41 - 25
calculated_remote.go

@@ -1,41 +1,36 @@
 package nebula
 
 import (
+	"encoding/binary"
 	"fmt"
 	"math"
 	"net"
+	"net/netip"
 	"strconv"
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 // This allows us to "guess" what the remote might be for a host while we wait
 // for the lighthouse response. See "lighthouse.calculated_remotes" in the
 // example config file.
 type calculatedRemote struct {
-	ipNet  net.IPNet
-	maskIP iputil.VpnIp
-	mask   iputil.VpnIp
-	port   uint32
+	ipNet netip.Prefix
+	mask  netip.Prefix
+	port  uint32
 }
 
-func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
-	// Ensure this is an IPv4 mask that we expect
-	ones, bits := ipNet.Mask.Size()
-	if ones == 0 || bits != 32 {
-		return nil, fmt.Errorf("invalid mask: %v", ipNet)
-	}
+func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
 	}
 
 	return &calculatedRemote{
-		ipNet:  *ipNet,
-		maskIP: iputil.Ip2VpnIp(ipNet.IP),
-		mask:   iputil.Ip2VpnIp(ipNet.Mask),
-		port:   uint32(port),
+		ipNet: maskCidr,
+		mask:  masked,
+		port:  uint32(port),
 	}, nil
 }
 
@@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 
-func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
+func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
 	// Combine the masked bytes of the "mask" IP with the unmasked bytes
 	// of the overlay IP
-	masked := (c.maskIP & c.mask) | (ip & ^c.mask)
+	if c.ipNet.Addr().Is4() {
+		return c.apply4(ip)
+	}
+	return c.apply6(ip)
+}
+
+func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
+	//TODO: IPV6-WORK this can be less crappy
+	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	mask := binary.BigEndian.Uint32(maskb[:])
+
+	b := c.mask.Addr().As4()
+	maskIp := binary.BigEndian.Uint32(b[:])
+
+	b = ip.As4()
+	intIp := binary.BigEndian.Uint32(b[:])
+
+	return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+}
 
-	return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
+func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
+	//TODO: IPV6-WORK
+	panic("Can not calculate ipv6 remote addresses")
 }
 
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
+	calculatedRemotes := new(bart.Table[[]*calculatedRemote])
 
 	rawMap, ok := value.(map[any]any)
 	if !ok {
@@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
 		}
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		cidr, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 
+		//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
 		entry, err := newCalculatedRemotesListFromConfig(rawValue)
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
 
-		calculatedRemotes.AddCIDR(ipNet, entry)
+		calculatedRemotes.Insert(cidr, entry)
 	}
 
 	return calculatedRemotes, nil
@@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 	if !ok {
 		return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
 	}
-	_, ipNet, err := net.ParseCIDR(rawMask)
+	maskCidr, err := netip.ParsePrefix(rawMask)
 	if err != nil {
 		return nil, fmt.Errorf("invalid mask: %s", rawMask)
 	}
@@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 
-	return newCalculatedRemote(ipNet, port)
+	return newCalculatedRemote(maskCidr, port)
 }

+ 7 - 9
calculated_remote_test.go

@@ -1,27 +1,25 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"testing"
 
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
 func TestCalculatedRemoteApply(t *testing.T) {
-	_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
+	ipNet, err := netip.ParsePrefix("192.168.1.0/24")
 	require.NoError(t, err)
 
 	c, err := newCalculatedRemote(ipNet, 4242)
 	require.NoError(t, err)
 
-	input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
+	input, err := netip.ParseAddr("10.0.10.182")
+	assert.NoError(t, err)
 
-	expected := &Ip4AndPort{
-		Ip:   uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
-		Port: 4242,
-	}
+	expected, err := netip.ParseAddr("192.168.1.182")
+	assert.NoError(t, err)
 
-	assert.Equal(t, expected, c.Apply(input))
+	assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
 }

+ 0 - 10
cidr/parse.go

@@ -1,10 +0,0 @@
-package cidr
-
-import "net"
-
-// Parse is a convenience function that returns only the IPNet
-// This function ignores errors since it is primarily a test helper, the result could be nil
-func Parse(s string) *net.IPNet {
-	_, c, _ := net.ParseCIDR(s)
-	return c
-}

+ 0 - 203
cidr/tree4.go

@@ -1,203 +0,0 @@
-package cidr
-
-import (
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-type Node[T any] struct {
-	left     *Node[T]
-	right    *Node[T]
-	parent   *Node[T]
-	hasValue bool
-	value    T
-}
-
-type entry[T any] struct {
-	CIDR  *net.IPNet
-	Value T
-}
-
-type Tree4[T any] struct {
-	root *Node[T]
-	list []entry[T]
-}
-
-const (
-	startbit = iputil.VpnIp(0x80000000)
-)
-
-func NewTree4[T any]() *Tree4[T] {
-	tree := new(Tree4[T])
-	tree.root = &Node[T]{}
-	tree.list = []entry[T]{}
-	return tree
-}
-
-func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
-	bit := startbit
-	node := tree.root
-	next := tree.root
-
-	ip := iputil.Ip2VpnIp(cidr.IP)
-	mask := iputil.Ip2VpnIp(cidr.Mask)
-
-	// Find our last ancestor in the tree
-	for bit&mask != 0 {
-		if ip&bit != 0 {
-			next = node.right
-		} else {
-			next = node.left
-		}
-
-		if next == nil {
-			break
-		}
-
-		bit = bit >> 1
-		node = next
-	}
-
-	// We already have this range so update the value
-	if next != nil {
-		addCIDR := cidr.String()
-		for i, v := range tree.list {
-			if addCIDR == v.CIDR.String() {
-				tree.list = append(tree.list[:i], tree.list[i+1:]...)
-				break
-			}
-		}
-
-		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-		node.value = val
-		node.hasValue = true
-		return
-	}
-
-	// Build up the rest of the tree we don't already have
-	for bit&mask != 0 {
-		next = &Node[T]{}
-		next.parent = node
-
-		if ip&bit != 0 {
-			node.right = next
-		} else {
-			node.left = next
-		}
-
-		bit >>= 1
-		node = next
-	}
-
-	// Final node marks our cidr, set the value
-	node.value = val
-	node.hasValue = true
-	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-}
-
-// Contains finds the first match, which may be the least specific
-func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			return true, node.value
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-
-	}
-
-	return false, value
-}
-
-// MostSpecificContains finds the most specific match
-func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			value = node.value
-			ok = true
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return ok, value
-}
-
-type eachFunc[T any] func(T) bool
-
-// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
-// The final return value will be true if the provided function returned true
-func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			// If the each func returns true then we can exit the loop
-			if each(node.value) {
-				return true
-			}
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return false
-}
-
-// GetCIDR returns the entry added by the most recent matching AddCIDR call
-func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	ip := iputil.Ip2VpnIp(cidr.IP)
-	mask := iputil.Ip2VpnIp(cidr.Mask)
-
-	// Find our last ancestor in the tree
-	for node != nil && bit&mask != 0 {
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit = bit >> 1
-	}
-
-	if bit&mask == 0 && node != nil {
-		value = node.value
-		ok = node.hasValue
-	}
-
-	return ok, value
-}
-
-// List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4[T]) List() []entry[T] {
-	return tree.list
-}

+ 0 - 170
cidr/tree4_test.go

@@ -1,170 +0,0 @@
-package cidr
-
-import (
-	"net"
-	"testing"
-
-	"github.com/slackhq/nebula/iputil"
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_List(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
-	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
-	tree.AddCIDR(Parse("1.0.0.0/16"), "4")
-	list := tree.List()
-	assert.Len(t, list, 2)
-	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
-	assert.Equal(t, "2", list[0].Value)
-	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
-	assert.Equal(t, "4", list[1].Value)
-}
-
-func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
-	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4a", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4b", "4.1.1.2"},
-		{true, "4c", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-}
-
-func TestTree4_GetCIDR(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
-	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IPNet  *net.IPNet
-	}{
-		{true, "1", Parse("1.0.0.0/8")},
-		{true, "2", Parse("2.1.0.0/16")},
-		{true, "3", Parse("3.1.1.0/24")},
-		{true, "4a", Parse("4.1.1.0/24")},
-		{true, "4b", Parse("4.1.1.1/32")},
-		{true, "4c", Parse("4.1.2.1/32")},
-		{true, "5", Parse("254.0.0.0/4")},
-		{false, "", Parse("2.0.0.0/8")},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.GetCIDR(tt.IPNet)
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
-
-	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-}

+ 0 - 189
cidr/tree6.go

@@ -1,189 +0,0 @@
-package cidr
-
-import (
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-const startbit6 = uint64(1 << 63)
-
-type Tree6[T any] struct {
-	root4 *Node[T]
-	root6 *Node[T]
-}
-
-func NewTree6[T any]() *Tree6[T] {
-	tree := new(Tree6[T])
-	tree.root4 = &Node[T]{}
-	tree.root6 = &Node[T]{}
-	return tree
-}
-
-func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
-	var node, next *Node[T]
-
-	cidrIP, ipv4 := isIPV4(cidr.IP)
-	if ipv4 {
-		node = tree.root4
-		next = tree.root4
-
-	} else {
-		node = tree.root6
-		next = tree.root6
-	}
-
-	for i := 0; i < len(cidrIP); i += 4 {
-		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
-		mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
-		bit := startbit
-
-		// Find our last ancestor in the tree
-		for bit&mask != 0 {
-			if ip&bit != 0 {
-				next = node.right
-			} else {
-				next = node.left
-			}
-
-			if next == nil {
-				break
-			}
-
-			bit = bit >> 1
-			node = next
-		}
-
-		// Build up the rest of the tree we don't already have
-		for bit&mask != 0 {
-			next = &Node[T]{}
-			next.parent = node
-
-			if ip&bit != 0 {
-				node.right = next
-			} else {
-				node.left = next
-			}
-
-			bit >>= 1
-			node = next
-		}
-	}
-
-	// Final node marks our cidr, set the value
-	node.value = val
-	node.hasValue = true
-}
-
-// Finds the most specific match
-func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
-	var node *Node[T]
-
-	wholeIP, ipv4 := isIPV4(ip)
-	if ipv4 {
-		node = tree.root4
-	} else {
-		node = tree.root6
-	}
-
-	for i := 0; i < len(wholeIP); i += 4 {
-		ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
-		bit := startbit
-
-		for node != nil {
-			if node.hasValue {
-				value = node.value
-				ok = true
-			}
-
-			if bit == 0 {
-				break
-			}
-
-			if ip&bit != 0 {
-				node = node.right
-			} else {
-				node = node.left
-			}
-
-			bit >>= 1
-		}
-	}
-
-	return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root4
-
-	for node != nil {
-		if node.hasValue {
-			value = node.value
-			ok = true
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
-	ip := hi
-	node := tree.root6
-
-	for i := 0; i < 2; i++ {
-		bit := startbit6
-
-		for node != nil {
-			if node.hasValue {
-				value = node.value
-				ok = true
-			}
-
-			if bit == 0 {
-				break
-			}
-
-			if ip&bit != 0 {
-				node = node.right
-			} else {
-				node = node.left
-			}
-
-			bit >>= 1
-		}
-
-		ip = lo
-	}
-
-	return ok, value
-}
-
-func isIPV4(ip net.IP) (net.IP, bool) {
-	if len(ip) == net.IPv4len {
-		return ip, true
-	}
-
-	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
-		return ip[12:16], true
-	}
-
-	return ip, false
-}
-
-func isZeros(p net.IP) bool {
-	for i := 0; i < len(p); i++ {
-		if p[i] != 0 {
-			return false
-		}
-	}
-	return true
-}

+ 0 - 98
cidr/tree6_test.go

@@ -1,98 +0,0 @@
-package cidr
-
-import (
-	"encoding/binary"
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewTree6[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4b", "4.1.1.2"},
-		{true, "4c", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{true, "6a", "1:2:0:4:1:1:1:1"},
-		{true, "6b", "1:2:0:4:5:1:1:1"},
-		{true, "6c", "1:2:0:4:5:0:0:0"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree6[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	tree.AddCIDR(Parse("::/0"), "cool6")
-	ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("::"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool6", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool6", r)
-}
-
-func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewTree6[string]()
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "6a", "1:2:0:4:1:1:1:1"},
-		{true, "6b", "1:2:0:4:5:1:1:1"},
-		{true, "6c", "1:2:0:4:5:0:0:0"},
-	}
-
-	for _, tt := range tests {
-		ip := net.ParseIP(tt.IP)
-		hi := binary.BigEndian.Uint64(ip[:8])
-		lo := binary.BigEndian.Uint64(ip[8:])
-
-		ok, r := tree.MostSpecificContainsIpV6(hi, lo)
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-}

+ 17 - 13
connection_manager.go

@@ -3,6 +3,8 @@ package nebula
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
+	"net/netip"
 	"sync"
 	"time"
 
@@ -10,8 +12,6 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 type trafficDecision int
@@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
 
 		var index uint32
-		var relayFrom iputil.VpnIp
-		var relayTo iputil.VpnIp
+		var relayFrom netip.Addr
+		var relayTo netip.Addr
 		switch {
 		case ok && existing.State == Established:
 			// This relay already exists in newhostinfo, then do nothing.
@@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnIp
+				relayFrom = n.intf.myVpnNet.Addr()
 				relayTo = existing.PeerIp
 			case ForwardingType:
 				relayFrom = existing.PeerIp
@@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnIp
+				relayFrom = n.intf.myVpnNet.Addr()
 				relayTo = r.PeerIp
 			case ForwardingType:
 				relayFrom = r.PeerIp
@@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			}
 		}
 
+		//TODO: IPV6-WORK
+		relayFromB := relayFrom.As4()
+		relayToB := relayTo.As4()
+
 		// Send a CreateRelayRequest to the peer.
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         uint32(relayFrom),
-			RelayToIp:           uint32(relayTo),
+			RelayFromIp:         binary.BigEndian.Uint32(relayFromB[:]),
+			RelayToIp:           binary.BigEndian.Uint32(relayToB[:]),
 		}
 		msg, err := req.Marshal()
 		if err != nil {
@@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
-				"relayTo":             iputil.VpnIp(req.RelayToIp),
+				"relayFrom":           req.RelayFromIp,
+				"relayTo":             req.RelayToIp,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
 				"vpnIp":               newhostinfo.vpnIp}).
@@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 
-	if current.vpnIp < n.intf.myVpnIp {
+	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
 		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
 		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
 		// The remotes vpn ip is lower than mine. I will not flip.
@@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 	}
 
 	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
+		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
 			n.metricsTxPunchy.Inc(1)
 			n.intf.outside.WriteTo([]byte{1}, addr)
 		})
 
-	} else if hostinfo.remote != nil {
+	} else if hostinfo.remote.IsValid() {
 		n.metricsTxPunchy.Inc(1)
 		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}

+ 17 - 17
connection_manager_test.go

@@ -5,28 +5,26 @@ import (
 	"crypto/ed25519"
 	"crypto/rand"
 	"net"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 )
 
-var vpnIp iputil.VpnIp
-
 func newTestLighthouse() *LightHouse {
 	lh := &LightHouse{
 		l:         test.NewLogger(),
-		addrMap:   map[iputil.VpnIp]*RemoteList{},
-		queryChan: make(chan iputil.VpnIp, 10),
+		addrMap:   map[netip.Addr]*RemoteList{},
+		queryChan: make(chan netip.Addr, 10),
 	}
-	lighthouses := map[iputil.VpnIp]struct{}{}
-	staticList := map[iputil.VpnIp]struct{}{}
+	lighthouses := map[netip.Addr]struct{}{}
+	staticList := map[netip.Addr]struct{}{}
 
 	lh.lighthouses.Store(&lighthouses)
 	lh.staticList.Store(&staticList)
@@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
 	hostMap := newHostMap(l, vpncidr)
@@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 
 	// Very incomplete mock objects
 	hostMap := newHostMap(l, vpncidr)
@@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		IP:   net.IPv4(172, 1, 1, 2),
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 	hostMap := newHostMap(l, vpncidr)
 	hostMap.preferredRanges.Store(&preferredRanges)
 

+ 19 - 21
control.go

@@ -2,7 +2,7 @@ package nebula
 
 import (
 	"context"
-	"net"
+	"net/netip"
 	"os"
 	"os/signal"
 	"syscall"
@@ -10,9 +10,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
-	"github.com/slackhq/nebula/udp"
 )
 
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -21,10 +19,10 @@ import (
 type controlEach func(h *HostInfo)
 
 type controlHostLister interface {
-	QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+	QueryVpnIp(vpnIp netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
 	ForEachVpnIp(each controlEach)
-	GetPreferredRanges() []*net.IPNet
+	GetPreferredRanges() []netip.Prefix
 }
 
 type Control struct {
@@ -39,15 +37,15 @@ type Control struct {
 }
 
 type ControlHostInfo struct {
-	VpnIp                  net.IP                  `json:"vpnIp"`
+	VpnIp                  netip.Addr              `json:"vpnIp"`
 	LocalIndex             uint32                  `json:"localIndex"`
 	RemoteIndex            uint32                  `json:"remoteIndex"`
-	RemoteAddrs            []*udp.Addr             `json:"remoteAddrs"`
+	RemoteAddrs            []netip.AddrPort        `json:"remoteAddrs"`
 	Cert                   *cert.NebulaCertificate `json:"cert"`
 	MessageCounter         uint64                  `json:"messageCounter"`
-	CurrentRemote          *udp.Addr               `json:"currentRemote"`
-	CurrentRelaysToMe      []iputil.VpnIp          `json:"currentRelaysToMe"`
-	CurrentRelaysThroughMe []iputil.VpnIp          `json:"currentRelaysThroughMe"`
+	CurrentRemote          netip.AddrPort          `json:"currentRemote"`
+	CurrentRelaysToMe      []netip.Addr            `json:"currentRelaysToMe"`
+	CurrentRelaysThroughMe []netip.Addr            `json:"currentRelaysThroughMe"`
 }
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -132,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 }
 
 // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	if pending {
 		hl = c.f.handshakeManager
@@ -150,19 +149,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 }
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
 	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo == nil {
 		return nil
 	}
 
-	hostInfo.SetRemote(addr.Copy())
+	hostInfo.SetRemote(addr)
 	ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
 	return &ch
 }
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo == nil {
 		return false
@@ -205,7 +206,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	}
 
 	// Learn which hosts are being used as relays, so we can shut them down last.
-	relayingHosts := map[iputil.VpnIp]*HostInfo{}
+	relayingHosts := map[netip.Addr]*HostInfo{}
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
@@ -236,15 +237,16 @@ func (c *Control) Device() overlay.Device {
 	return c.f.inside
 }
 
-func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
+func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp.ToIP(),
+		VpnIp:                  h.vpnIp,
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
+		CurrentRemote:          h.remote,
 	}
 
 	if h.ConnectionState != nil {
@@ -255,10 +257,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 		chi.Cert = c.Copy()
 	}
 
-	if h.remote != nil {
-		chi.CurrentRemote = h.remote.Copy()
-	}
-
 	return chi
 }
 

+ 33 - 24
control_test.go

@@ -2,15 +2,14 @@ package nebula
 
 import (
 	"net"
+	"net/netip"
 	"reflect"
 	"testing"
 	"time"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
-	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, &net.IPNet{})
-	hm.preferredRanges.Store(&[]*net.IPNet{})
+	hm := newHostMap(l, netip.Prefix{})
+	hm.preferredRanges.Store(&[]netip.Prefix{})
+
+	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
+	remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
 
-	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
-	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
+		IP:   remote1.Addr().AsSlice(),
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 
 	ipNet2 := net.IPNet{
-		IP:   net.ParseIP("1:2:3:4:5:6:7:8"),
+		IP:   remote2.Addr().AsSlice(),
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 
@@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 
 	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
-	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+
+	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
+	assert.True(t, ok)
+
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
@@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp:         vpnIp,
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 	}, &Interface{})
 
+	vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
+	assert.True(t, ok)
+
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
@@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		remoteIndexId: 200,
 		localIndexId:  201,
-		vpnIp:         iputil.Ip2VpnIp(ipNet2.IP),
+		vpnIp:         vpnIp2,
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 	}, &Interface{})
@@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 	}
 
-	thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
+	thi := c.GetHostInfoByVpnIp(vpnIp, false)
 
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  net.IPv4(1, 2, 3, 4).To4(),
+		VpnIp:                  vpnIp,
 		LocalIndex:             201,
 		RemoteIndex:            200,
-		RemoteAddrs:            []*udp.Addr{remote2, remote1},
+		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
 		Cert:                   crt.Copy(),
 		MessageCounter:         0,
-		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
-		CurrentRelaysToMe:      []iputil.VpnIp{},
-		CurrentRelaysThroughMe: []iputil.VpnIp{},
+		CurrentRemote:          remote1,
+		CurrentRelaysToMe:      []netip.Addr{},
+		CurrentRelaysThroughMe: []netip.Addr{},
 	}
 
 	// Make sure we don't have any unexpected fields
 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
-	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
+	assert.EqualValues(t, &expectedInfo, thi)
+	//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
+	//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
+		thi = c.GetHostInfoByVpnIp(vpnIp2, false)
 	})
 }
 

+ 20 - 27
control_tester.go

@@ -4,14 +4,13 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 
 	"github.com/slackhq/nebula/cert"
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
@@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 
 // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
 // This is necessary if you did not configure static hosts or are not running a lighthouse
-func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
+func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	iVpnIp := iputil.Ip2VpnIp(vpnIp)
-	if v4 := toAddr.IP.To4(); v4 != nil {
-		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+	if toAddr.Addr().Is4() {
+		remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
 	} else {
-		remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+		remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
 	}
 }
 
 // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
 // This is necessary to inform an initiator of possible relays for communicating with a responder
-func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) {
+func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 
-	iVpnIp := iputil.Ip2VpnIp(vpnIp)
-	uVpnIp := []uint32{}
-	for _, rVPnIp := range relayVpnIps {
-		uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
-	}
-
-	remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
+	remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
 }
 
 // GetFromTun will pull a packet off the tun side of nebula
@@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
+func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
+	//TODO: IPV6-WORK
 	ip := layers.IPv4{
 		Version:  4,
 		TTL:      64,
 		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().IP,
-		DstIP:    toIp,
+		SrcIP:    c.f.inside.Cidr().Addr().Unmap().AsSlice(),
+		DstIP:    toIp.Unmap().AsSlice(),
 	}
 
 	udp := layers.UDP{
@@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 
-func (c *Control) GetVpnIp() iputil.VpnIp {
-	return c.f.myVpnIp
+func (c *Control) GetVpnIp() netip.Addr {
+	return c.f.myVpnNet.Addr()
 }
 
-func (c *Control) GetUDPAddr() string {
-	return c.f.outside.(*udp.TesterConn).Addr.String()
+func (c *Control) GetUDPAddr() netip.AddrPort {
+	return c.f.outside.(*udp.TesterConn).Addr
 }
 
-func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
+	hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
 	if hostinfo == nil {
 		return false
 	}
@@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
 	return c.f.pki.GetCertState().Certificate
 }
 
-func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+func (c *Control) ReHandshake(vpnIp netip.Addr) {
 	c.f.handshakeManager.StartHandshake(vpnIp, nil)
 }

+ 12 - 6
dns_server.go

@@ -3,6 +3,7 @@ package nebula
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"strconv"
 	"strings"
 	"sync"
@@ -10,7 +11,6 @@ import (
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 // This whole thing should be rewritten to use context
@@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string {
 }
 
 func (d *dnsRecords) QueryCert(data string) string {
-	ip := net.ParseIP(data[:len(data)-1])
-	if ip == nil {
+	ip, err := netip.ParseAddr(data[:len(data)-1])
+	if err != nil {
 		return ""
 	}
-	iip := iputil.Ip2VpnIp(ip)
-	hostinfo := d.hostMap.QueryVpnIp(iip)
+
+	hostinfo := d.hostMap.QueryVpnIp(ip)
 	if hostinfo == nil {
 		return ""
 	}
+
 	q := hostinfo.GetCert()
 	if q == nil {
 		return ""
 	}
+
 	cert := q.Details
 	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
 	return c
@@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 			}
 		case dns.TypeTXT:
 			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b := net.ParseIP(a)
+			b, err := netip.ParseAddr(a)
+			if err != nil {
+				return
+			}
+
 			// We don't answer these queries from non nebula nodes or localhost
 			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
 			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {

+ 169 - 169
e2e/handshakes_test.go

@@ -5,7 +5,7 @@ package e2e
 
 import (
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 	"time"
 
@@ -13,19 +13,18 @@ import (
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"gopkg.in/yaml.v2"
 )
 
 func BenchmarkHotPath(b *testing.B) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 	// Start the servers
 	myControl.Start()
@@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) {
 	r.CancelFlowLogs()
 
 	for n := 0; n < b.N; n++ {
-		myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+		myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
 	}
 
@@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) {
 }
 
 func TestGoodHandshake(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 	// Start the servers
 	myControl.Start()
 	theirControl.Start()
 
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.WaitForType(1, 0, theirControl)
 
 	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
 
 	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 	t.Log("Do a bidirectional tunnel test")
 	r := router.NewR(t, myControl, theirControl)
 	defer r.RenderFlow()
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
@@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) {
 }
 
 func TestWrongResponderHandshake(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
-	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
+	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
 
 	// Add their real udp addr, which should be tried after evil.
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl, evilControl)
@@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 	evilControl.Start()
 
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 		h := &header.H{}
 		err := h.Parse(p.Data)
@@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 			panic(err)
 		}
 
-		if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+		if p.To == theirUdpAddr && h.Type == 1 {
 			return router.RouteAndExit
 		}
 
@@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 	t.Log("Test the tunnel with them")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	t.Log("Flush all packets from all controllers")
 	r.FlushAll()
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 	//TODO: assert hostmaps for everyone
@@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) {
 	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
 	// But will eventually collapse down to a single tunnel
 
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
@@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Trigger a handshake to start on both me and them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 	t.Log("Get both stage 1 handshake packets")
 	myHsForThem := myControl.GetFromUDP(true)
@@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) {
 
 	r.Log("Route until they receive a message packet")
 	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 	r.Log("Their cached packet should be received by me")
 	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
-	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
 
 	r.Log("Do a bidirectional tunnel test")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	myHostmapHosts := myControl.ListHostmapHosts(false)
 	myHostmapIndexes := myControl.ListHostmapIndexes(false)
@@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) {
 	r.Log("Spin until connection manager tears down a tunnel")
 
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 	}
@@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
@@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 	theirControl.Start()
 
 	r.Log("Trigger a handshake from me to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 	r.Log("Nuke my hostmap")
 	myHostmap := myControl.GetHostmap()
-	myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
 	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
 	p = r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 	r.Log("Assert the tunnel works")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	r.Log("Wait for the dead index to go away")
 	start := len(theirControl.GetHostmap().Indexes)
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		if len(theirControl.GetHostmap().Indexes) < start {
 			break
 		}
@@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 }
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
@@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 	theirControl.Start()
 
 	r.Log("Trigger a handshake from me to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 
 	r.Log("Nuke my hostmap")
 	theirHostmap := theirControl.GetHostmap()
-	theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
 	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
 	p = r.RouteForAllUntilTxTun(myControl)
-	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
 
 	r.Log("Assert the tunnel works")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	r.Log("Wait for the dead index to go away")
 	start := len(myControl.GetHostmap().Indexes)
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		if len(myControl.GetHostmap().Indexes) < start {
 			break
 		}
@@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 }
 
 func TestRelays(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -361,31 +360,31 @@ func TestRelays(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 }
 
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
 
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
 
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) {
 	theirControl.Start()
 
 	r.Log("Get a tunnel between me and relay")
-	assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
 
 	r.Log("Get a tunnel between them and relay")
-	assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
 
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 	r.Log("Wait for a packet from them to me")
 	p := r.RouteForAllUntilTxTun(myControl)
@@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) {
 
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 	l := NewTestLogger()
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
 
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
 
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 	r.Log("Get a tunnel between me and relay")
 	l.Info("Get a tunnel between me and relay")
-	assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
 
 	r.Log("Get a tunnel between them and relay")
 	l.Info("Get a tunnel between them and relay")
-	assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
 
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	l.Info("Trigger a handshake from both them and me via relay to them and me")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 	//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 	r.Log("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 
 	t.Log("Wait until we remove extra tunnels")
 	l.Info("Wait until we remove extra tunnels")
@@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 				"theirControl": len(theirControl.GetHostmap().Indexes),
 				"relayControl": len(relayControl.GetHostmap().Indexes),
 			}).Info("Waiting for hostinfos to be removed...")
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		retries--
@@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 	r.Log("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 
 	myControl.Stop()
 	theirControl.Stop()
@@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) {
 	//
 	////TODO: assert hostmaps
 }
+
 func TestRehandshakingRelays(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) {
 
 	for {
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
-		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			r.Log("Certificate between my and relay is updated!")
@@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) {
 
 	for {
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
-		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			r.Log("Certificate between their and relay is updated!")
@@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) {
 	}
 
 	r.Log("Assert the relay tunnel still works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	// We should have two hostinfos on all sides
 	for len(myControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 	}
@@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 	}
@@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 	}
@@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) {
 
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 	for {
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
-		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			r.Log("Certificate between my and relay is updated!")
@@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 	for {
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
-		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			r.Log("Certificate between their and relay is updated!")
@@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	}
 
 	r.Log("Assert the relay tunnel still works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	// We should have two hostinfos on all sides
 	for len(myControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 	}
@@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 	}
@@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 	}
@@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 }
 
 func TestRehandshaking(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
 
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
@@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Stand up a tunnel between me and them")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
@@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) {
 	myConfig.ReloadConfigString(string(rc))
 
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
-		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			break
@@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) {
 
 	r.Log("Spin until there is only 1 tunnel")
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 	}
 
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 	// Make sure the correct tunnel won
-	c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
 	assert.Contains(t, c.Cert.Details.Groups, "new group")
 
 	// We should only have a single tunnel now on both sides
@@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) {
 func TestRehandshakingLoser(t *testing.T) {
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// Should be the one with the new certificate
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
 
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
@@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Stand up a tunnel between me and them")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
-	tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
-	tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+	tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
 	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
 
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) {
 	theirConfig.ReloadConfigString(string(rc))
 
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
-		theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
 
 		_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
 		if theirNewGroup {
@@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) {
 
 	r.Log("Spin until there is only 1 tunnel")
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 	}
 
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 	// Make sure the correct tunnel won
-	theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+	theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
 	assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
 
 	// We should only have a single tunnel now on both sides
@@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) {
 	// This test forces stage 1, stage 2, stage 1 to be received by me from them
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// caused a cross-linked hostinfo
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 	// Start the servers
 	myControl.Start()
@@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) {
 	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
 
 	t.Log("Start both handshakes")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 	t.Log("Get both stage 1")
 	myStage1ForThem := myControl.GetFromUDP(true)
@@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	t.Log("Make sure the tunnel still works")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 	myControl.Stop()
 	theirControl.Stop()

+ 15 - 8
e2e/helpers.go

@@ -4,6 +4,7 @@ import (
 	"crypto/rand"
 	"io"
 	"net"
+	"net/netip"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
@@ -12,7 +13,7 @@ import (
 )
 
 // NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
 	if before.IsZero() {
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
@@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 	}
 
 	if len(ips) > 0 {
-		nc.Details.Ips = ips
+		nc.Details.Ips = make([]*net.IPNet, len(ips))
+		for i, ip := range ips {
+			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+		}
 	}
 
 	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
+		nc.Details.Subnets = make([]*net.IPNet, len(subnets))
+		for i, ip := range subnets {
+			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+		}
 	}
 
 	if len(groups) > 0 {
@@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 
 // NewTestCert will generate a signed certificate with the provided details.
 // Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
 	issuer, err := ca.Sha256Sum()
 	if err != nil {
 		panic(err)
@@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
 	}
 
 	pub, rawPriv := x25519Keypair()
-
+	ipb := ip.Addr().AsSlice()
 	nc := &cert.NebulaCertificate{
 		Details: cert.NebulaCertificateDetails{
-			Name:           name,
-			Ips:            []*net.IPNet{ip},
-			Subnets:        subnets,
+			Name: name,
+			Ips:  []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
+			//Subnets:        subnets,
 			Groups:         groups,
 			NotBefore:      time.Unix(before.Unix(), 0),
 			NotAfter:       time.Unix(after.Unix(), 0),

+ 28 - 24
e2e/helpers_test.go

@@ -6,7 +6,7 @@ package e2e
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"testing"
 	"time"
@@ -19,7 +19,6 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"gopkg.in/yaml.v2"
 )
@@ -27,15 +26,23 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 
-	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
-	copy(vpnIpNet.IP, udpIp)
-	vpnIpNet.IP[1] += 128
-	udpAddr := net.UDPAddr{
-		IP:   udpIp,
-		Port: 4242,
+	vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
+	if err != nil {
+		panic(err)
+	}
+
+	var udpAddr netip.AddrPort
+	if vpnIpNet.Addr().Is4() {
+		budpIp := vpnIpNet.Addr().As4()
+		budpIp[1] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+	} else {
+		budpIp := vpnIpNet.Addr().As16()
+		budpIp[13] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
 	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 
@@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		//	"try_interval": "1s",
 		//},
 		"listen": m{
-			"host": udpAddr.IP.String(),
-			"port": udpAddr.Port,
+			"host": udpAddr.Addr().String(),
+			"port": udpAddr.Port(),
 		},
 		"logging": m{
 			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
@@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 	}
 
-	return control, vpnIpNet, &udpAddr, c
+	return control, vpnIpNet, udpAddr, c
 }
 
 type doneCb func()
@@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 	}
 }
 
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 
-func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
+	hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
 	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
 
-	hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
+	hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
 	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
 
 	// Check that both vpn and real addr are correct
 	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
 	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
 
-	assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
-	assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
-
-	assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
-	assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
+	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
+	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
 
 	// Check that our indexes match
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
@@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB
 	//checkIndexes("hmB", hmB, hAinB)
 }
 
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
+func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
 
-	assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
-	assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
+	assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
 
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
 	assert.NotNil(t, udp, "No udp data found")

+ 4 - 4
e2e/router/hostmap.go

@@ -5,11 +5,11 @@ package router
 
 import (
 	"fmt"
+	"net/netip"
 	"sort"
 	"strings"
 
 	"github.com/slackhq/nebula"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type edge struct {
@@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	return r, globalLines
 }
 
-func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
-	keys := make([]iputil.VpnIp, 0, len(hosts))
+func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
+	keys := make([]netip.Addr, 0, len(hosts))
 	for key := range hosts {
 		keys = append(keys, key)
 	}
 
 	sort.SliceStable(keys, func(i, j int) bool {
-		return keys[i] > keys[j]
+		return keys[i].Compare(keys[j]) > 0
 	})
 
 	return keys

+ 36 - 63
e2e/router/router.go

@@ -6,12 +6,11 @@ package router
 import (
 	"context"
 	"fmt"
-	"net"
+	"net/netip"
 	"os"
 	"path/filepath"
 	"reflect"
 	"sort"
-	"strconv"
 	"strings"
 	"sync"
 	"testing"
@@ -21,7 +20,6 @@ import (
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"golang.org/x/exp/maps"
 )
@@ -29,18 +27,18 @@ import (
 type R struct {
 	// Simple map of the ip:port registered on a control to the control
 	// Basically a router, right?
-	controls map[string]*nebula.Control
+	controls map[netip.AddrPort]*nebula.Control
 
 	// A map for inbound packets for a control that doesn't know about this address
-	inNat map[string]*nebula.Control
+	inNat map[netip.AddrPort]*nebula.Control
 
 	// A last used map, if an inbound packet hit the inNat map then
 	// all return packets should use the same last used inbound address for the outbound sender
 	// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
-	outNat map[string]net.UDPAddr
+	outNat map[string]netip.AddrPort
 
 	// A map of vpn ip to the nebula control it belongs to
-	vpnControls map[iputil.VpnIp]*nebula.Control
+	vpnControls map[netip.Addr]*nebula.Control
 
 	ignoreFlows []ignoreFlow
 	flow        []flowEntry
@@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	}
 
 	r := &R{
-		controls:     make(map[string]*nebula.Control),
-		vpnControls:  make(map[iputil.VpnIp]*nebula.Control),
-		inNat:        make(map[string]*nebula.Control),
-		outNat:       make(map[string]net.UDPAddr),
+		controls:     make(map[netip.AddrPort]*nebula.Control),
+		vpnControls:  make(map[netip.Addr]*nebula.Control),
+		inNat:        make(map[netip.AddrPort]*nebula.Control),
+		outNat:       make(map[string]netip.AddrPort),
 		flow:         []flowEntry{},
 		ignoreFlows:  []ignoreFlow{},
 		fn:           filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	for _, c := range controls {
 		addr := c.GetUDPAddr()
 		if _, ok := r.controls[addr]; ok {
-			panic("Duplicate listen address: " + addr)
+			panic("Duplicate listen address: " + addr.String())
 		}
 
 		r.vpnControls[c.GetVpnIp()] = c
@@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 // It does not look at the addr attached to the instance.
 // If a route is used, this will behave like a NAT for the return path.
 // Rewriting the source ip:port to what was last sent to from the origin
-func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
+func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
 	r.Lock()
 	defer r.Unlock()
 
-	inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
+	inAddr := netip.AddrPortFrom(ip, port)
 	if _, ok := r.inNat[inAddr]; ok {
-		panic("Duplicate listen address inNat: " + inAddr)
+		panic("Duplicate listen address inNat: " + inAddr.String())
 	}
 	r.inNat[inAddr] = c
 }
@@ -198,7 +196,7 @@ func (r *R) renderFlow() {
 		panic(err)
 	}
 
-	var participants = map[string]struct{}{}
+	var participants = map[netip.AddrPort]struct{}{}
 	var participantsVals []string
 
 	fmt.Fprintln(f, "```mermaid")
@@ -215,7 +213,7 @@ func (r *R) renderFlow() {
 			continue
 		}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr, ":", "-", 1)
+		sanAddr := strings.Replace(addr.String(), ":", "-", 1)
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
@@ -252,9 +250,9 @@ func (r *R) renderFlow() {
 
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
+				strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
 				line,
-				strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+				strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 		}
@@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp() > c[j].GetVpnIp()
+		return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
 	})
 
 	s := renderHostmaps(c...)
@@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 
 		// Nope, lets push the sender along
 		case p := <-udpTx:
-			outAddr := sender.GetUDPAddr()
 			r.Lock()
-			inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-			c := r.getControl(outAddr, inAddr, p)
+			c := r.getControl(sender.GetUDPAddr(), p.To, p)
 			if c == nil {
 				r.Unlock()
 				panic("No control for udp tx")
@@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
-			outAddr := cm[x].GetUDPAddr()
-
-			inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-			c := r.getControl(outAddr, inAddr, p)
+			c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 			if c == nil {
 				r.Unlock()
 				panic("No control for udp tx")
@@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 			panic(err)
 		}
 
-		outAddr := sender.GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
 		if receiver == nil {
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't RouteExitFunc for host: " + p.To.String())
 		}
 
 		e := whatDo(p, receiver)
@@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
 // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
 // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
 	if finish == KeepRouting {
 		finish = RouteAndExit
 	}
 
 	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
-		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
+		if p.To == toAddr {
 			return finish
 		}
 
@@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
 		r.Lock()
 
 		p := rx.Interface().(*udp.Packet)
-
-		outAddr := cm[x].GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 		if receiver == nil {
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't RouteForAllExitFunc for host: " + p.To.String())
 		}
 
 		e := whatDo(p, receiver)
@@ -697,12 +685,10 @@ func (r *R) FlushAll() {
 
 		p := rx.Interface().(*udp.Packet)
 
-		outAddr := cm[x].GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 		if receiver == nil {
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't FlushAll for host: " + p.To.String())
 		}
 		r.Unlock()
 	}
@@ -710,28 +696,14 @@ func (r *R) FlushAll() {
 
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
-	if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
-		p.FromIp = newAddr.IP
-		p.FromPort = uint16(newAddr.Port)
+func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
+	if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
+		p.From = newAddr
 	}
 
 	c, ok := r.inNat[toAddr]
 	if ok {
-		sHost, sPort, err := net.SplitHostPort(toAddr)
-		if err != nil {
-			panic(err)
-		}
-
-		port, err := strconv.Atoi(sPort)
-		if err != nil {
-			panic(err)
-		}
-
-		r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
-			IP:   net.ParseIP(sHost),
-			Port: port,
-		}
+		r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
 		return c
 	}
 
@@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string {
 	}
 
 	from := "unknown"
-	if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok {
-		from = c.GetUDPAddr()
+	srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
+	if c, ok := r.vpnControls[srcAddr]; ok {
+		from = c.GetUDPAddr().String()
 	}
 
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
@@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string {
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
 		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+		strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
 		udp.SrcPort,
 		udp.DstPort,
 		string(data.Payload()),

+ 58 - 42
firewall.go

@@ -6,23 +6,23 @@ import (
 	"errors"
 	"fmt"
 	"hash/fnv"
-	"net"
+	"net/netip"
 	"reflect"
 	"strconv"
 	"strings"
 	"sync"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 )
 
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
 }
 
 type conn struct {
@@ -52,8 +52,8 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps     *cidr.Tree4[struct{}]
-	assignedCIDR *net.IPNet
+	localIps     *bart.Table[struct{}]
+	assignedCIDR netip.Prefix
 	hasSubnets   bool
 
 	rules        string
@@ -108,7 +108,7 @@ type FirewallRule struct {
 	Any    *firewallLocalCIDR
 	Hosts  map[string]*firewallLocalCIDR
 	Groups []*firewallGroups
-	CIDR   *cidr.Tree4[*firewallLocalCIDR]
+	CIDR   *bart.Table[*firewallLocalCIDR]
 }
 
 type firewallGroups struct {
@@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
 
 type firewallLocalCIDR struct {
 	Any       bool
-	LocalCIDR *cidr.Tree4[struct{}]
+	LocalCIDR *bart.Table[struct{}]
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 	}
 
-	localIps := cidr.NewTree4[struct{}]()
-	var assignedCIDR *net.IPNet
+	localIps := new(bart.Table[struct{}])
+	var assignedCIDR netip.Prefix
+	var assignedSet bool
 	for _, ip := range c.Details.Ips {
-		ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
-		localIps.AddCIDR(ipNet, struct{}{})
+		//TODO: IPV6-WORK the unmap is a bit unfortunate
+		nip, _ := netip.AddrFromSlice(ip.IP)
+		nip = nip.Unmap()
+		nprefix := netip.PrefixFrom(nip, nip.BitLen())
+		localIps.Insert(nprefix, struct{}{})
 
-		if assignedCIDR == nil {
+		if !assignedSet {
 			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = ipNet
+			assignedCIDR = nprefix
+			assignedSet = true
 		}
 	}
 
 	for _, n := range c.Details.Subnets {
-		localIps.AddCIDR(n, struct{}{})
+		nip, _ := netip.AddrFromSlice(n.IP)
+		ones, _ := n.Mask.Size()
+		nip = nip.Unmap()
+		localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
 	}
 
 	return &Firewall{
@@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
 }
 
 // AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// https://github.com/golang/go/issues/14131
 	sIp := ""
-	if ip != nil {
+	if ip.IsValid() {
 		sIp = ip.String()
 	}
 	lIp := ""
-	if localIp != nil {
+	if localIp.IsValid() {
 		lIp = localIp.String()
 	}
 
@@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 		}
 
-		var cidr *net.IPNet
+		var cidr netip.Prefix
 		if r.Cidr != "" {
-			_, cidr, err = net.ParseCIDR(r.Cidr)
+			cidr, err = netip.ParsePrefix(r.Cidr)
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
 			}
 		}
 
-		var localCidr *net.IPNet
+		var localCidr netip.Prefix
 		if r.LocalCidr != "" {
-			_, localCidr, err = net.ParseCIDR(r.LocalCidr)
+			localCidr, err = netip.ParsePrefix(r.LocalCidr)
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
 			}
@@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 
 	// Make sure remote address matches nebula certificate
 	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		ok, _ := remoteCidr.Contains(fp.RemoteIP)
+		//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+		_, ok := remoteCidr.Lookup(fp.RemoteIP)
 		if !ok {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
@@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	ok, _ := f.localIps.Contains(fp.LocalIP)
+	//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+	_, ok := f.localIps.Lookup(fp.LocalIP)
 	if !ok {
 		f.metrics(incoming).droppedLocalIP.Inc(1)
 		return ErrInvalidLocalIP
@@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Caller must own the connMutex lock!
 func (f *Firewall) evict(p firewall.Packet) {
-	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
 	conntrack := f.Conntrack
 	t, ok := conntrack.Conns[p]
@@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 	return false
 }
 
-func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 	}
@@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 
-func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
 	fr := func() *FirewallRule {
 		return &FirewallRule{
 			Hosts:  make(map[string]*firewallLocalCIDR),
 			Groups: make([]*firewallGroups, 0),
-			CIDR:   cidr.NewTree4[*firewallLocalCIDR](),
+			CIDR:   new(bart.Table[*firewallLocalCIDR]),
 		}
 	}
 
@@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 	return fc.CANames[s.Details.Name].match(p, c)
 }
 
-func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
 	flc := func() *firewallLocalCIDR {
 		return &firewallLocalCIDR{
-			LocalCIDR: cidr.NewTree4[struct{}](),
+			LocalCIDR: new(bart.Table[struct{}]),
 		}
 	}
 
@@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
 		fr.Hosts[host] = nlc
 	}
 
-	if ip != nil {
-		_, nlc := fr.CIDR.GetCIDR(ip)
+	if ip.IsValid() {
+		nlc, _ := fr.CIDR.Get(ip)
 		if nlc == nil {
 			nlc = flc()
 		}
@@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
 		if err != nil {
 			return err
 		}
-		fr.CIDR.AddCIDR(ip, nlc)
+		fr.CIDR.Insert(ip, nlc)
 	}
 
 	return nil
 }
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
-	if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
+	if len(groups) == 0 && host == "" && !ip.IsValid() {
 		return true
 	}
 
@@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 		return true
 	}
 
-	if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
+	if ip.IsValid() && ip.Bits() == 0 {
 		return true
 	}
 
@@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		}
 	}
 
-	return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
-		return flc.match(p, c)
+	matched := false
+	prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
+	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
+		if prefix.Contains(p.RemoteIP) && val.match(p, c) {
+			matched = true
+			return false
+		}
+		return true
 	})
+	return matched
 }
 
-func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
-	if localIp == nil {
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
+	if !localIp.IsValid() {
 		if !f.hasSubnets || f.defaultLocalCIDRAny {
 			flc.Any = true
 			return nil
 		}
 
 		localIp = f.assignedCIDR
-	} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
+	} else if localIp.Bits() == 0 {
 		flc.Any = true
 	}
 
-	flc.LocalCIDR.AddCIDR(localIp, struct{}{})
+	flc.LocalCIDR.Insert(localIp, struct{}{})
 	return nil
 }
 
@@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
 		return true
 	}
 
-	ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
 	return ok
 }
 

+ 3 - 4
firewall/packet.go

@@ -3,8 +3,7 @@ package firewall
 import (
 	"encoding/json"
 	"fmt"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 
 type m map[string]interface{}
@@ -20,8 +19,8 @@ const (
 )
 
 type Packet struct {
-	LocalIP    iputil.VpnIp
-	RemoteIP   iputil.VpnIp
+	LocalIP    netip.Addr
+	RemoteIP   netip.Addr
 	LocalPort  uint16
 	RemotePort uint16
 	Protocol   uint8

+ 74 - 73
firewall_test.go

@@ -5,13 +5,13 @@ import (
 	"errors"
 	"math"
 	"net"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
@@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
 
-	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
+	ti, err := netip.ParsePrefix("1.2.3.4/32")
+	assert.NoError(t, err)
 
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
-	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
+	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
-	ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
+	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
+	anyIp, err := netip.ParsePrefix("0.0.0.0/0")
+	assert.NoError(t, err)
+
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
-	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
+	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 }
 
 func TestFirewall_Drop(t *testing.T) {
@@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
@@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr("1.2.3.4"),
 	}
 	h.CreateRemoteCIDR(&c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
-	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
+	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
 	// Allow inbound
 	resetConntrack(fw)
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
@@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) {
 
 	// test remote mismatch
 	oldRemote := p.RemoteIP
-	p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
+	p.RemoteIP = netip.MustParseAddr("1.2.3.10")
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	p.RemoteIP = oldRemote
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 
@@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		TCP: firewallPort{},
 	}
 
-	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
-	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
-	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
+	pfix := netip.MustParsePrefix("172.1.1.1/32")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 
 	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
 		c := &cert.NebulaCertificate{}
-		ip, _, _ := net.ParseCIDR("9.254.254.254/32")
-		lip := iputil.Ip2VpnIp(ip)
+		ip := netip.MustParsePrefix("9.254.254.254/32")
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
 		}
 	})
 
@@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
 		}
 	})
 
@@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
 		}
 	})
 
@@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
@@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	h.CreateRemoteCIDR(&c)
 
@@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.CreateRemoteCIDR(&c1)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
 	// h1/c1 lacks the proper groups
@@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  1,
 		RemotePort: 1,
 		Protocol:   firewall.ProtoUDP,
@@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	h1.CreateRemoteCIDR(&c1)
 
@@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	h2.CreateRemoteCIDR(&c2)
 
@@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	h3.CreateRemoteCIDR(&c3)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
 	cp := cert.NewCAPool()
 
 	// c1 should pass because host match
@@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l.SetOutput(ob)
 
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
@@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	h.CreateRemoteCIDR(&c)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 
 	// Drop outbound
@@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
@@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
+	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
+	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 	// Test both group and groups
 	conf = config.NewC(l)
@@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding udp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding any rule
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding rule with cidr
-	cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
+	cidr := netip.MustParsePrefix("10.0.0.0/8")
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test adding rule with local_cidr
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
 
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
 
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
 
 	// Test single group
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test single groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 	// Test Add error
 	conf = config.NewC(l)
@@ -871,8 +872,8 @@ type addRuleCall struct {
 	endPort   int32
 	groups    []string
 	host      string
-	ip        *net.IPNet
-	localIp   *net.IPNet
+	ip        netip.Prefix
+	localIp   netip.Prefix
 	caName    string
 	caSha     string
 }
@@ -882,7 +883,7 @@ type mockFirewall struct {
 	nextCallReturn error
 }
 
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
 	mf.lastCall = addRuleCall{
 		incoming:  incoming,
 		proto:     proto,

+ 2 - 0
go.mod

@@ -38,8 +38,10 @@ require (
 
 require (
 	github.com/beorn7/perks v1.0.1 // indirect
+	github.com/bits-and-blooms/bitset v1.13.0 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/gaissmai/bart v0.11.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.5.0 // indirect

+ 6 - 0
go.sum

@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
+github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
 github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -24,6 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI=
+github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
+github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
+github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=

+ 42 - 16
handshake_ix.go

@@ -1,13 +1,12 @@
 package nebula
 
 import (
+	"net/netip"
 	"time"
 
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 // NOISE IX Handshakes
@@ -63,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	return true
 }
 
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	certState := f.pki.GetCertState()
 	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
@@ -99,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		e.Info("Invalid certificate from host")
 		return
 	}
-	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+
+	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+	if !ok {
+		e := f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
+
+		if f.l.Level > logrus.DebugLevel {
+			e = e.WithField("cert", remoteCert)
+		}
+
+		e.Info("Invalid vpn ip from host")
+		return
+	}
+
+	vpnIp = vpnIp.Unmap()
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 
-	if vpnIp == f.myVpnIp {
+	if vpnIp == f.myVpnNet.Addr() {
 		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
@@ -113,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		return
 	}
 
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
+	if addr.IsValid() {
+		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
 			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
@@ -138,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 	}
@@ -218,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-			if addr != nil {
+			if addr.IsValid() {
 				err := f.outside.WriteTo(msg, addr)
 				if err != nil {
 					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
@@ -284,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 
 	// Do the send
 	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-	if addr != nil {
+	if addr.IsValid() {
 		err = f.outside.WriteTo(msg, addr)
 		if err != nil {
 			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
@@ -326,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
 	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -336,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	defer hh.Unlock()
 
 	hostinfo := hh.hostinfo
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
+	if addr.IsValid() {
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
@@ -389,7 +402,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 		return true
 	}
 
-	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+	if !ok {
+		e := f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
+
+		if f.l.Level > logrus.DebugLevel {
+			e = e.WithField("cert", remoteCert)
+		}
+
+		e.Info("Invalid vpn ip from host")
+		return true
+	}
+
+	vpnIp = vpnIp.Unmap()
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
@@ -453,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	ci.eKey = NewNebulaCipherState(eKey)
 
 	// Make sure the current udpAddr being used is set for responding
-	if addr != nil {
+	if addr.IsValid() {
 		hostinfo.SetRemote(addr)
 	} else {
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)

+ 50 - 41
handshake_manager.go

@@ -6,15 +6,15 @@ import (
 	"crypto/rand"
 	"encoding/binary"
 	"errors"
-	"net"
+	"net/netip"
 	"sync"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
+	"golang.org/x/exp/slices"
 )
 
 const (
@@ -46,14 +46,14 @@ type HandshakeManager struct {
 	// Mutex for interacting with the vpnIps and indexes maps
 	sync.RWMutex
 
-	vpnIps  map[iputil.VpnIp]*HandshakeHostInfo
+	vpnIps  map[netip.Addr]*HandshakeHostInfo
 	indexes map[uint32]*HandshakeHostInfo
 
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
 	outside                udp.Conn
 	config                 HandshakeConfig
-	OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
+	OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
@@ -61,17 +61,17 @@ type HandshakeManager struct {
 	l                      *logrus.Logger
 
 	// can be used to trigger outbound handshake for the given vpnIp
-	trigger chan iputil.VpnIp
+	trigger chan netip.Addr
 }
 
 type HandshakeHostInfo struct {
 	sync.Mutex
 
-	startTime   time.Time       // Time that we first started trying with this handshake
-	ready       bool            // Is the handshake ready
-	counter     int             // How many attempts have we made so far
-	lastRemotes []*udp.Addr     // Remotes that we sent to during the previous attempt
-	packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+	startTime   time.Time        // Time that we first started trying with this handshake
+	ready       bool             // Is the handshake ready
+	counter     int              // How many attempts have we made so far
+	lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
+	packetStore []*cachedPacket  // A set of packets to be transmitted once the handshake completes
 
 	hostinfo *HostInfo
 }
@@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
 
 func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
+		vpnIps:                 map[netip.Addr]*HandshakeHostInfo{},
 		indexes:                map[uint32]*HandshakeHostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
 		config:                 config,
-		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
-		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		trigger:                make(chan netip.Addr, config.triggerBuffer),
+		OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 	}
 }
 
-func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
-	if addr != nil {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+	if addr.IsValid() {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
@@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	}
 }
 
-func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
 	hh := hm.queryVpnIp(vpnIp)
 	if hh == nil {
 		return
@@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	}
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
-	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
+	remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// This is a very specific optimization for a fast lighthouse reply.
@@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
-	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
+	var sentTo []netip.AddrPort
+	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
 		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
@@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
+			if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
 				continue
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
-			if relayHostInfo == nil || relayHostInfo.remote == nil {
+			relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				hm.f.Handshake(*relay)
+				hm.f.Handshake(relay)
 				continue
 			}
 			// Check the relay HostInfo to see if we already established a relay through it
@@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 				case Requested:
 					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+
+					//TODO: IPV6-WORK
+					myVpnIpB := hm.f.myVpnNet.Addr().As4()
+					theirVpnIpB := vpnIp.As4()
+
 					// Re-send the CreateRelay request, in case the previous one was lost.
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
-						RelayToIp:           uint32(vpnIp),
+						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
+						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
@@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 						// This must send over the hostinfo, not over hm.Hosts[ip]
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.lightHouse.myVpnIp,
+							"relayFrom":           hm.f.myVpnNet.Addr(),
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
-							"relay":               *relay}).
+							"relay":               relay}).
 							Info("send CreateRelayRequest")
 					}
 				default:
@@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 				}
 			} else {
 				// No relays exist or requested yet.
-				if relayHostInfo.remote != nil {
+				if relayHostInfo.remote.IsValid() {
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					if err != nil {
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 
+					//TODO: IPV6-WORK
+					myVpnIpB := hm.f.myVpnNet.Addr().As4()
+					theirVpnIpB := vpnIp.As4()
+
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
-						RelayToIp:           uint32(vpnIp),
+						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
+						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
@@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.lightHouse.myVpnIp,
+							"relayFrom":           hm.f.myVpnNet.Addr(),
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
-							"relay":               *relay}).
+							"relay":               relay}).
 							Info("send CreateRelayRequest")
 					}
 				}
@@ -355,7 +364,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 
 // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
 // The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
+func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	hm.mainHostMap.RLock()
 	h, ok := hm.mainHostMap.Hosts[vpnIp]
 	hm.mainHostMap.RUnlock()
@@ -372,7 +381,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 }
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 
 	if hh, ok := hm.vpnIps[vpnIp]; ok {
@@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 		vpnIp:           vpnIp,
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 	}
@@ -555,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
 func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	delete(c.vpnIps, hostinfo.vpnIp)
 	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
+		c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 
 	delete(c.indexes, hostinfo.localIndexId)
@@ -570,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 }
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 		return hh.hostinfo
@@ -579,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 
 }
 
-func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
 	hm.RLock()
 	defer hm.RUnlock()
 	return hm.vpnIps[vpnIp]
@@ -599,7 +608,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 }
 
-func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
+func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
 	return c.mainHostMap.GetPreferredRanges()
 }
 

+ 9 - 9
handshake_manager_test.go

@@ -1,13 +1,12 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"testing"
 	"time"
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
@@ -15,10 +14,11 @@ import (
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	ip := netip.MustParseAddr("172.1.1.2")
+
+	preferredRanges := []netip.Prefix{localrange}
 	mainHM := newHostMap(l, vpncidr)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
@@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.NotContains(t, blah.vpnIps, ip)
 }
 
-func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
+func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 	for _, i := range tw.t.wheel {
 		n := i.Head
 		for n != nil {
@@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
 type mockEncWriter struct {
 }
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
 	return
 }
 
@@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	return
 }
 
-func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}
+func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}

+ 75 - 71
hostmap.go

@@ -3,18 +3,17 @@ package nebula
 import (
 	"errors"
 	"net"
+	"net/netip"
 	"sync"
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 // const ProbeLen = 100
@@ -49,7 +48,7 @@ type Relay struct {
 	State       int
 	LocalIndex  uint32
 	RemoteIndex uint32
-	PeerIp      iputil.VpnIp
+	PeerIp      netip.Addr
 }
 
 type HostMap struct {
@@ -57,9 +56,9 @@ type HostMap struct {
 	Indexes         map[uint32]*HostInfo
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	RemoteIndexes   map[uint32]*HostInfo
-	Hosts           map[iputil.VpnIp]*HostInfo
-	preferredRanges atomic.Pointer[[]*net.IPNet]
-	vpnCIDR         *net.IPNet
+	Hosts           map[netip.Addr]*HostInfo
+	preferredRanges atomic.Pointer[[]netip.Prefix]
+	vpnCIDR         netip.Prefix
 	l               *logrus.Logger
 }
 
@@ -69,12 +68,12 @@ type HostMap struct {
 type RelayState struct {
 	sync.RWMutex
 
-	relays        map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[iputil.VpnIp]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay         // Maps a local index to some Relay info
+	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
+	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
 }
 
-func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
+func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
 	delete(rs.relays, ip)
@@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 }
 
-func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
 	r, ok := rs.relayForByIp[ip]
 	return r, ok
 }
 
-func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) {
+func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
 	rs.Lock()
 	defer rs.Unlock()
 	rs.relays[ip] = struct{}{}
 }
 
-func (rs *RelayState) CopyRelayIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayIps() []netip.Addr {
 	rs.RLock()
 	defer rs.RUnlock()
-	ret := make([]iputil.VpnIp, 0, len(rs.relays))
+	ret := make([]netip.Addr, 0, len(rs.relays))
 	for ip := range rs.relays {
 		ret = append(ret, ip)
 	}
 	return ret
 }
 
-func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	defer rs.RUnlock()
-	currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp))
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
 	for relayIp := range rs.relayForByIp {
 		currentRelays = append(currentRelays, relayIp)
 	}
@@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 	return ret
 }
 
-func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
-	rs.Lock()
-	defer rs.Unlock()
-	r, ok := rs.relayForByIdx[localIdx]
-	if !ok {
-		return iputil.VpnIp(0), false
-	}
-	delete(rs.relayForByIdx, localIdx)
-	delete(rs.relayForByIp, r.PeerIp)
-	return r.PeerIp, true
-}
-
-func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
+func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	defer rs.Unlock()
 	r, ok := rs.relayForByIp[vpnIp]
@@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	return &newRelay, true
 }
 
-func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
 	r, ok := rs.relayForByIp[vpnIp]
@@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 	return r, ok
 }
 
-func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
+func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	defer rs.Unlock()
 	rs.relayForByIp[ip] = r
@@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
 }
 
 type HostInfo struct {
-	remote          *udp.Addr
+	remote          netip.AddrPort
 	remotes         *RemoteList
 	promoteCounter  atomic.Uint32
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	localIndexId    uint32
-	vpnIp           iputil.VpnIp
+	vpnIp           netip.Addr
 	recvError       atomic.Uint32
-	remoteCidr      *cidr.Tree4[struct{}]
+	remoteCidr      *bart.Table[struct{}]
 	relayState      RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
@@ -227,7 +214,7 @@ type HostInfo struct {
 	lastHandshakeTime uint64
 
 	lastRoam       time.Time
-	lastRoamRemote *udp.Addr
+	lastRoamRemote netip.AddrPort
 
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Synchronised via hostmap lock and not the hostinfo lock.
@@ -254,7 +241,7 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
+func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
 	hm := newHostMap(l, vpnCIDR)
 
 	hm.reload(c, true)
@@ -269,12 +256,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho
 	return hm
 }
 
-func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
+func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
 	return &HostMap{
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
-		Hosts:         map[iputil.VpnIp]*HostInfo{},
+		Hosts:         map[netip.Addr]*HostInfo{},
 		vpnCIDR:       vpnCIDR,
 		l:             l,
 	}
@@ -282,11 +269,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
 
 func (hm *HostMap) reload(c *config.C, initial bool) {
 	if initial || c.HasChanged("preferred_ranges") {
-		var preferredRanges []*net.IPNet
+		var preferredRanges []netip.Prefix
 		rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
 
 		for _, rawPreferredRange := range rawPreferredRanges {
-			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
+			preferredRange, err := netip.ParsePrefix(rawPreferredRange)
 
 			if err != nil {
 				hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
@@ -378,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
 		delete(hm.Hosts, hostinfo.vpnIp)
 		if len(hm.Hosts) == 0 {
-			hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 
 		if hostinfo.next != nil {
@@ -461,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 }
 
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
 	return hm.queryVpnIp(vpnIp, nil)
 }
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	defer hm.RUnlock()
 
@@ -483,7 +470,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
 	return nil, nil, errors.New("unable to find host with relay")
 }
 
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -535,7 +522,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }
 
-func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
+func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
 	return *hm.preferredRanges.Load()
 }
@@ -560,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
-func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
+func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
 	if c%ifce.tryPromoteEvery.Load() == 0 {
 		remote := i.remote
 
 		// return early if we are already on a preferred remote
-		if remote != nil {
-			rIP := remote.IP
+		if remote.IsValid() {
+			rIP := remote.Addr()
 			for _, l := range preferredRanges {
 				if l.Contains(rIP) {
 					return
@@ -575,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 		}
 
-		i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
-			if remote != nil && (addr == nil || !preferred) {
+		i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
+			if remote.IsValid() && (!addr.IsValid() || !preferred) {
 				return
 			}
 
@@ -605,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	return nil
 }
 
-func (i *HostInfo) SetRemote(remote *udp.Addr) {
+func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
-	if !i.remote.Equals(remote) {
-		i.remote = remote.Copy()
-		i.remotes.LearnRemote(i.vpnIp, remote.Copy())
+	if i.remote != remote {
+		i.remote = remote
+		i.remotes.LearnRemote(i.vpnIp, remote)
 	}
 }
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
-	if newRemote == nil {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
+	if !newRemote.IsValid() {
 		// relays have nil udp Addrs
 		return false
 	}
 	currentRemote := i.remote
-	if currentRemote == nil {
+	if !currentRemote.IsValid() {
 		i.SetRemote(newRemote)
 		return true
 	}
@@ -631,11 +618,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	newIsPreferred := false
 	for _, l := range hm.GetPreferredRanges() {
 		// return early if we are already on a preferred remote
-		if l.Contains(currentRemote.IP) {
+		if l.Contains(currentRemote.Addr()) {
 			return false
 		}
 
-		if l.Contains(newRemote.IP) {
+		if l.Contains(newRemote.Addr()) {
 			newIsPreferred = true
 		}
 	}
@@ -643,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	if newIsPreferred {
 		// Consider this a roaming event
 		i.lastRoam = time.Now()
-		i.lastRoamRemote = currentRemote.Copy()
+		i.lastRoamRemote = currentRemote
 
 		i.SetRemote(newRemote)
 
@@ -666,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 	}
 
-	remoteCidr := cidr.NewTree4[struct{}]()
+	remoteCidr := new(bart.Table[struct{}])
 	for _, ip := range c.Details.Ips {
-		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+		//TODO: IPV6-WORK what to do when ip is invalid?
+		nip, _ := netip.AddrFromSlice(ip.IP)
+		nip = nip.Unmap()
+		bits, _ := ip.Mask.Size()
+		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
 	}
 
 	for _, n := range c.Details.Subnets {
-		remoteCidr.AddCIDR(n, struct{}{})
+		//TODO: IPV6-WORK what to do when ip is invalid?
+		nip, _ := netip.AddrFromSlice(n.IP)
+		nip = nip.Unmap()
+		bits, _ := n.Mask.Size()
+		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
 	}
 	i.remoteCidr = remoteCidr
 }
@@ -697,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 // Utility functions
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
+func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
-	var ips []net.IP
+	var ips []netip.Addr
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
@@ -721,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
 				ip = v.IP
 			}
 
+			nip, ok := netip.AddrFromSlice(ip)
+			if !ok {
+				if l.Level >= logrus.DebugLevel {
+					l.WithField("localIp", ip).Debug("ip was invalid for netip")
+				}
+				continue
+			}
+			nip = nip.Unmap()
+
 			//TODO: Filtering out link local for now, this is probably the most correct thing
 			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
-				allow := allowList.Allow(ip)
+			if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
+				allow := allowList.Allow(nip)
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
 				}
 				if !allow {
 					continue
 				}
 
-				ips = append(ips, ip)
+				ips = append(ips, nip)
 			}
 		}
 	}
-	return &ips
+	return ips
 }

+ 25 - 34
hostmap_test.go

@@ -1,7 +1,7 @@
 package nebula
 
 import (
-	"net"
+	"net/netip"
 	"testing"
 
 	"github.com/slackhq/nebula/config"
@@ -13,18 +13,15 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
 	hm := newHostMap(
 		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
+		netip.MustParsePrefix("10.0.0.1/24"),
 	)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
-	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
-	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
-	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
 	hm := newHostMap(
 		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
+		netip.MustParsePrefix("10.0.0.1/24"),
 	)
 
 	f := &Interface{}
 
-	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
-	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
-	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
-	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
-	h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
-	h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+	h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
+	h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -159,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
@@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 }
 
@@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) {
 
 	hm := NewHostMapFromConfig(
 		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
+		netip.MustParsePrefix("10.0.0.1/24"),
 		c,
 	)
 
-	toS := func(ipn []*net.IPNet) []string {
+	toS := func(ipn []netip.Prefix) []string {
 		var s []string
 		for _, n := range ipn {
 			s = append(s, n.String())

+ 4 - 2
hostmap_tester.go

@@ -5,9 +5,11 @@ package nebula
 
 // This file contains functions used to export information to the e2e testing framework
 
-import "github.com/slackhq/nebula/iputil"
+import (
+	"net/netip"
+)
 
-func (i *HostInfo) GetVpnIp() iputil.VpnIp {
+func (i *HostInfo) GetVpnIp() netip.Addr {
 	return i.vpnIp
 }
 

+ 20 - 24
inside.go

@@ -1,12 +1,13 @@
 package nebula
 
 import (
+	"net/netip"
+
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
-	"github.com/slackhq/nebula/udp"
 )
 
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
+	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
 		return
 	}
 
-	if fwPacket.RemoteIP == f.myVpnIp {
+	if fwPacket.RemoteIP == f.myVpnNet.Addr() {
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// routes packets from the Nebula IP to the Nebula IP through the Nebula
@@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	// Ignore broadcast packets
-	if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
+	// Ignore multicast packets
+	if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
 		return
 	}
 
@@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
+		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
 
 	} else {
 		f.rejectInside(packet, out, q)
@@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 	}
 
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
+	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 
-func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
+func (f *Interface) Handshake(vpnIp netip.Addr) {
 	f.getOrHandshake(vpnIp, nil)
 }
 
 // getOrHandshake returns nil if the vpnIp is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
+func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	if !f.myVpnNet.Contains(vpnIp) {
 		vpnIp = f.inside.RouteFor(vpnIp)
-		if vpnIp == 0 {
+		if !vpnIp.IsValid() {
 			return nil, false
 		}
 	}
@@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 	}
 
-	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
 	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
@@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 
-func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
+func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
@@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo,
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
 		//TODO: log warning
 		return
 	}
-	useRelay := remote == nil && hostinfo.remote == nil
+	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
 	fullOut := out
 
 	if useRelay {
@@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		return
 	}
 
-	if remote != nil {
+	if remote.IsValid() {
 		err = f.writers[q].WriteTo(out, remote)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
 				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 		}
-	} else if hostinfo.remote != nil {
+	} else if hostinfo.remote.IsValid() {
 		err = f.writers[q].WriteTo(out, hostinfo.remote)
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
@@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		}
 	}
 }
-
-func isMulticast(ip iputil.VpnIp) bool {
-	// Class D multicast
-	return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
-}

+ 36 - 11
interface.go

@@ -2,10 +2,11 @@ package nebula
 
 import (
 	"context"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"runtime"
 	"sync/atomic"
@@ -16,7 +17,6 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 )
@@ -63,8 +63,8 @@ type Interface struct {
 	serveDns           bool
 	createTime         time.Time
 	lightHouse         *LightHouse
-	localBroadcast     iputil.VpnIp
-	myVpnIp            iputil.VpnIp
+	myBroadcastAddr    netip.Addr
+	myVpnNet           netip.Prefix
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	routines           int
@@ -102,9 +102,9 @@ type EncWriter interface {
 		out []byte,
 		nocopy bool,
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp iputil.VpnIp)
+	Handshake(vpnIp netip.Addr)
 }
 
 type sendRecvErrorConfig uint8
@@ -115,10 +115,10 @@ const (
 	sendRecvErrorPrivate
 )
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
 	switch s {
 	case sendRecvErrorPrivate:
-		return ip.IsPrivate()
+		return ip.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 		return true
 	case sendRecvErrorNever:
@@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}
 
 	certificate := c.pki.GetCertState().Certificate
-	myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
+
+	myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+	if !ok {
+		return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
+	}
+
+	myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
+	if !ok {
+		return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
+	}
+
+	myVpnAddr = myVpnAddr.Unmap()
+	myVpnMask = myVpnMask.Unmap()
+
+	if myVpnAddr.BitLen() != myVpnMask.BitLen() {
+		return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
+	}
+
+	ones, _ := certificate.Details.Ips[0].Mask.Size()
+	myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
+
 	ifce := &Interface{
 		pki:                c.pki,
 		hostMap:            c.HostMap,
@@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		handshakeManager:   c.HandshakeManager,
 		createTime:         time.Now(),
 		lightHouse:         c.lightHouse,
-		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
 		routines:           c.routines,
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnIp:            myVpnIp,
+		myVpnNet:           myVpnNet,
 		relayManager:       c.relayManager,
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
@@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 	}
 
+	if myVpnAddr.Is4() {
+		addr := myVpnNet.Masked().Addr().As4()
+		binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
+		ifce.myBroadcastAddr = netip.AddrFrom4(addr)
+	}
+
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))

+ 2 - 0
iputil/packet.go

@@ -6,6 +6,8 @@ import (
 	"golang.org/x/net/ipv4"
 )
 
+//TODO: IPV6-WORK can probably delete this
+
 const (
 	// Need 96 bytes for the largest reject packet:
 	// - 20 byte ipv4 header

+ 0 - 93
iputil/util.go

@@ -1,93 +0,0 @@
-package iputil
-
-import (
-	"encoding/binary"
-	"fmt"
-	"net"
-	"net/netip"
-)
-
-type VpnIp uint32
-
-const maxIPv4StringLen = len("255.255.255.255")
-
-func (ip VpnIp) String() string {
-	b := make([]byte, maxIPv4StringLen)
-
-	n := ubtoa(b, 0, byte(ip>>24))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip>>16&255))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip>>8&255))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip&255))
-	return string(b[:n])
-}
-
-func (ip VpnIp) MarshalJSON() ([]byte, error) {
-	return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
-}
-
-func (ip VpnIp) ToIP() net.IP {
-	nip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(nip, uint32(ip))
-	return nip
-}
-
-func (ip VpnIp) ToNetIpAddr() netip.Addr {
-	var nip [4]byte
-	binary.BigEndian.PutUint32(nip[:], uint32(ip))
-	return netip.AddrFrom4(nip)
-}
-
-func Ip2VpnIp(ip []byte) VpnIp {
-	if len(ip) == 16 {
-		return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
-	}
-	return VpnIp(binary.BigEndian.Uint32(ip))
-}
-
-func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
-	addr, ok := netip.AddrFromSlice(ip)
-	if !ok {
-		return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
-	}
-	return addr, nil
-}
-
-func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
-	addr, err := ToNetIpAddr(ipNet.IP)
-	if err != nil {
-		return netip.Prefix{}, err
-	}
-	ones, bits := ipNet.Mask.Size()
-	if ones == 0 && bits == 0 {
-		return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
-	}
-	return netip.PrefixFrom(addr, ones), nil
-}
-
-// ubtoa encodes the string form of the integer v to dst[start:] and
-// returns the number of bytes written to dst. The caller must ensure
-// that dst has sufficient length.
-func ubtoa(dst []byte, start int, v byte) int {
-	if v < 10 {
-		dst[start] = v + '0'
-		return 1
-	} else if v < 100 {
-		dst[start+1] = v%10 + '0'
-		dst[start] = v/10 + '0'
-		return 2
-	}
-
-	dst[start+2] = v%10 + '0'
-	dst[start+1] = (v/10)%10 + '0'
-	dst[start] = v/100 + '0'
-	return 3
-}

+ 0 - 17
iputil/util_test.go

@@ -1,17 +0,0 @@
-package iputil
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestVpnIp_String(t *testing.T) {
-	assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
-	assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
-	assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
-	assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
-	assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
-	assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
-}

+ 212 - 188
lighthouse.go

@@ -7,16 +7,16 @@ import (
 	"fmt"
 	"net"
 	"net/netip"
+	"strconv"
 	"sync"
 	"sync/atomic"
 	"time"
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
 )
@@ -26,25 +26,18 @@ import (
 
 var ErrHostNotKnown = errors.New("host not known")
 
-type netIpAndPort struct {
-	ip   net.IP
-	port uint16
-}
-
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
 	ctx          context.Context
 	amLighthouse bool
-	myVpnIp      iputil.VpnIp
-	myVpnZeros   iputil.VpnIp
-	myVpnNet     *net.IPNet
+	myVpnNet     netip.Prefix
 	punchConn    udp.Conn
 	punchy       *Punchy
 
 	// Local cache of answers from light houses
 	// map of vpn Ip to answers
-	addrMap map[iputil.VpnIp]*RemoteList
+	addrMap map[netip.Addr]*RemoteList
 
 	// filters remote addresses allowed for each host
 	// - When we are a lighthouse, this filters what addresses we store and
@@ -57,26 +50,26 @@ type LightHouse struct {
 	localAllowList atomic.Pointer[LocalAllowList]
 
 	// used to trigger the HandshakeManager when we receive HostQueryReply
-	handshakeTrigger chan<- iputil.VpnIp
+	handshakeTrigger chan<- netip.Addr
 
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
-	staticList  atomic.Pointer[map[iputil.VpnIp]struct{}]
-	lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
+	staticList  atomic.Pointer[map[netip.Addr]struct{}]
+	lighthouses atomic.Pointer[map[netip.Addr]struct{}]
 
 	interval     atomic.Int64
 	updateCancel context.CancelFunc
 	ifce         EncWriter
 	nebulaPort   uint32 // 32 bits because protobuf does not have a uint16
 
-	advertiseAddrs atomic.Pointer[[]netIpAndPort]
+	advertiseAddrs atomic.Pointer[[]netip.AddrPort]
 
 	// IP's of relays that can be used by peers to access me
-	relaysForMe atomic.Pointer[[]iputil.VpnIp]
+	relaysForMe atomic.Pointer[[]netip.Addr]
 
-	queryChan chan iputil.VpnIp
+	queryChan chan netip.Addr
 
-	calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
+	calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
 
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
@@ -85,7 +78,7 @@ type LightHouse struct {
 
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) {
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	if amLighthouse && nebulaPort == 0 {
@@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		if err != nil {
 			return nil, util.NewContextualError("Failed to get listening port", nil, err)
 		}
-		nebulaPort = uint32(uPort.Port)
+		nebulaPort = uint32(uPort.Port())
 	}
 
-	ones, _ := myVpnNet.Mask.Size()
 	h := LightHouse{
 		ctx:          ctx,
 		amLighthouse: amLighthouse,
-		myVpnIp:      iputil.Ip2VpnIp(myVpnNet.IP),
-		myVpnZeros:   iputil.VpnIp(32 - ones),
 		myVpnNet:     myVpnNet,
-		addrMap:      make(map[iputil.VpnIp]*RemoteList),
+		addrMap:      make(map[netip.Addr]*RemoteList),
 		nebulaPort:   nebulaPort,
 		punchConn:    pc,
 		punchy:       p,
-		queryChan:    make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
+		queryChan:    make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
 		l:            l,
 	}
-	lighthouses := make(map[iputil.VpnIp]struct{})
+	lighthouses := make(map[netip.Addr]struct{})
 	h.lighthouses.Store(&lighthouses)
-	staticList := make(map[iputil.VpnIp]struct{})
+	staticList := make(map[netip.Addr]struct{})
 	h.staticList.Store(&staticList)
 
 	if c.GetBool("stats.lighthouse_metrics", false) {
@@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 	return &h, nil
 }
 
-func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
 	return *lh.staticList.Load()
 }
 
-func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
 	return *lh.lighthouses.Load()
 }
 
@@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
 	return lh.localAllowList.Load()
 }
 
-func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort {
+func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort {
 	return *lh.advertiseAddrs.Load()
 }
 
-func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
+func (lh *LightHouse) GetRelaysForMe() []netip.Addr {
 	return *lh.relaysForMe.Load()
 }
 
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
+func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] {
 	return lh.calculatedRemotes.Load()
 }
 
@@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 {
 func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	if initial || c.HasChanged("lighthouse.advertise_addrs") {
 		rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
-		advAddrs := make([]netIpAndPort, 0)
+		advAddrs := make([]netip.AddrPort, 0)
 
 		for i, rawAddr := range rawAdvAddrs {
-			fIp, fPort, err := udp.ParseIPAndPort(rawAddr)
+			host, sport, err := net.SplitHostPort(rawAddr)
 			if err != nil {
 				return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
 			}
 
-			if fPort == 0 {
-				fPort = uint16(lh.nebulaPort)
+			ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
+			if err != nil {
+				return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+			}
+			if len(ips) == 0 {
+				return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil)
+			}
+
+			port, err := strconv.Atoi(sport)
+			if err != nil {
+				return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+			}
+
+			if port == 0 {
+				port = int(lh.nebulaPort)
 			}
 
-			if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) {
+			//TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used
+			ip := ips[0].Unmap()
+			if lh.myVpnNet.Contains(ip) {
 				lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
 					Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
 				continue
 			}
 
-			advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort})
+			advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port)))
 		}
 
 		lh.advertiseAddrs.Store(&advAddrs)
@@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 			lh.RUnlock()
 		}
 		// Build a new list based on current config.
-		staticList := make(map[iputil.VpnIp]struct{})
-		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
+		staticList := make(map[netip.Addr]struct{})
+		err := lh.loadStaticMap(c, staticList)
 		if err != nil {
 			return err
 		}
@@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 
 	if initial || c.HasChanged("lighthouse.hosts") {
-		lhMap := make(map[iputil.VpnIp]struct{})
-		err := lh.parseLighthouses(c, lh.myVpnNet, lhMap)
+		lhMap := make(map[netip.Addr]struct{})
+		err := lh.parseLighthouses(c, lhMap)
 		if err != nil {
 			return err
 		}
@@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 			if len(c.GetStringSlice("relay.relays", nil)) > 0 {
 				lh.l.Info("Ignoring relays from config because am_relay is true")
 			}
-			relaysForMe := []iputil.VpnIp{}
+			relaysForMe := []netip.Addr{}
 			lh.relaysForMe.Store(&relaysForMe)
 		case false:
-			relaysForMe := []iputil.VpnIp{}
+			relaysForMe := []netip.Addr{}
 			for _, v := range c.GetStringSlice("relay.relays", nil) {
 				lh.l.WithField("relay", v).Info("Read relay from config")
 
-				configRIP := net.ParseIP(v)
-				if configRIP != nil {
-					relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP))
+				configRIP, err := netip.ParseAddr(v)
+				//TODO: We could print the error here
+				if err == nil {
+					relaysForMe = append(relaysForMe, configRIP)
 				}
 			}
 			lh.relaysForMe.Store(&relaysForMe)
@@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
 	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
 	if lh.amLighthouse && len(lhs) != 0 {
 		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
 
 	for i, host := range lhs {
-		ip := net.ParseIP(host)
-		if ip == nil {
-			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+		ip, err := netip.ParseAddr(host)
+		if err != nil {
+			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
 		}
-		if !tunCidr.Contains(ip) {
-			return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+		if !lh.myVpnNet.Contains(ip) {
+			return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil)
 		}
-		lhMap[iputil.Ip2VpnIp(ip)] = struct{}{}
+		lhMap[ip] = struct{}{}
 	}
 
 	if !lh.amLighthouse && len(lhMap) == 0 {
@@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) {
 	return network, nil
 }
 
-func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
 	d, err := getStaticMapCadence(c)
 	if err != nil {
 		return err
@@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 		return err
 	}
 
-	lookup_timeout, err := getStaticMapLookupTimeout(c)
+	lookupTimeout, err := getStaticMapLookupTimeout(c)
 	if err != nil {
 		return err
 	}
@@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	i := 0
 
 	for k, v := range shm {
-		rip := net.ParseIP(fmt.Sprintf("%v", k))
-		if rip == nil {
-			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil)
+		vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k))
+		if err != nil {
+			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
 		}
 
-		if !tunCidr.Contains(rip) {
-			return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil)
+		if !lh.myVpnNet.Contains(vpnIp) {
+			return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil)
 		}
 
-		vpnIp := iputil.Ip2VpnIp(rip)
 		vals, ok := v.([]interface{})
 		if !ok {
 			vals = []interface{}{v}
@@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
 		}
 
-		err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+		err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList)
 		if err != nil {
 			return err
 		}
@@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	return nil
 }
 
-func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) Query(ip netip.Addr) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip)
 	}
@@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
 }
 
 // QueryServer is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
+func (lh *LightHouse) QueryServer(ip netip.Addr) {
 	// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
 	if lh.amLighthouse || lh.IsLighthouseIP(ip) {
 		return
@@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
 	lh.queryChan <- ip
 }
 
-func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList {
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
@@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
 // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
 // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
-func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
+func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
 	lh.RLock()
 	// Do we have an entry in the main cache?
 	if v, ok := lh.addrMap[vpnIp]; ok {
@@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
 	return false, 0, nil
 }
 
-func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
+func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) {
 	// First we check the static mapping
 	// and do nothing if it is there
 	if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
@@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
@@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
 	am.unlockedSetHostnamesResults(hr)
 
 	for _, addrPort := range hr.GetIPs() {
-
+		if !lh.shouldAdd(vpnIp, addrPort.Addr()) {
+			continue
+		}
 		switch {
 		case addrPort.Addr().Is4():
-			to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
-			if !lh.unlockedShouldAddV4(vpnIp, to) {
-				continue
-			}
-			am.unlockedPrependV4(lh.myVpnIp, to)
+			am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
 		case addrPort.Addr().Is6():
-			to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
-			if !lh.unlockedShouldAddV6(vpnIp, to) {
-				continue
-			}
-			am.unlockedPrependV6(lh.myVpnIp, to)
+			am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
 		}
 	}
 
@@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
 // addCalculatedRemotes adds any calculated remotes based on the
 // lighthouse.calculated_remotes configuration. It returns true if any
 // calculated remotes were added
-func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool {
 	tree := lh.getCalculatedRemotes()
 	if tree == nil {
 		return false
 	}
-	ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+	calculatedRemotes, ok := tree.Lookup(vpnIp)
 	if !ok {
 		return false
 	}
@@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 	defer am.Unlock()
 	lh.Unlock()
 
-	am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
+	am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4)
 
 	return len(calculated) > 0
 }
 
 // unlockedGetRemoteList assumes you have the lh lock
-func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList {
 	am, ok := lh.addrMap[vpnIp]
 	if !ok {
 		am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
@@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 	return am
 }
 
-func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
-	switch {
-	case to.Is4():
-		ipBytes := to.As4()
-		ip := iputil.Ip2VpnIp(ipBytes[:])
-		allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
-		if lh.l.Level >= logrus.TraceLevel {
-			lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
-		}
-		if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
-			return false
-		}
-	case to.Is6():
-		ipBytes := to.As16()
-
-		hi := binary.BigEndian.Uint64(ipBytes[:8])
-		lo := binary.BigEndian.Uint64(ipBytes[8:])
-		allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
-		if lh.l.Level >= logrus.TraceLevel {
-			lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
-		}
-
-		// We don't check our vpn network here because nebula does not support ipv6 on the inside
-		if !allow {
-			return false
-		}
+func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool {
+	allow := lh.GetRemoteAllowList().Allow(vpnIp, to)
+	if lh.l.Level >= logrus.TraceLevel {
+		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+	}
+	if !allow || lh.myVpnNet.Contains(to) {
+		return false
 	}
+
 	return true
 }
 
 // unlockedShouldAddV4 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
-	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
+func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool {
+	ip := AddrPortFromIp4AndPort(to)
+	allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 
-	if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
+	if !allow || lh.myVpnNet.Contains(ip.Addr()) {
 		return false
 	}
 
@@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
 }
 
 // unlockedShouldAddV6 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
-	allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo)
+func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool {
+	ip := AddrPortFromIp6AndPort(to)
+	allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 
-	// We don't check our vpn network here because nebula does not support ipv6 on the inside
-	if !allow {
+	if !allow || lh.myVpnNet.Contains(ip.Addr()) {
 		return false
 	}
 
@@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 	return ip
 }
 
-func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool {
 	if _, ok := lh.GetLighthouses()[vpnIp]; ok {
 		return true
 	}
 	return false
 }
 
-func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
+func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta {
+	if vpnIp.Is6() {
+		//TODO: need to support ipv6
+		panic("ipv6 is not yet supported")
+	}
+
+	b := vpnIp.As4()
 	return &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
-			VpnIp: uint32(VpnIp),
+			VpnIp: binary.BigEndian.Uint32(b[:]),
 		},
 	}
 }
 
-func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
-	ipp := Ip4AndPort{Port: port}
-	ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
-	return &ipp
+func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort {
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], ip.Ip)
+	return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port))
+}
+
+func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort {
+	b := [16]byte{}
+	binary.BigEndian.PutUint64(b[:8], ip.Hi)
+	binary.BigEndian.PutUint64(b[8:], ip.Lo)
+	return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port))
 }
 
 func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
@@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
 	}
 }
 
-func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
-	return &Ip6AndPort{
-		Hi:   binary.BigEndian.Uint64(ip[:8]),
-		Lo:   binary.BigEndian.Uint64(ip[8:]),
-		Port: port,
-	}
-}
-
+// TODO: IPV6-WORK we can delete some more of these
 func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
 	ip6Addr := ip.As16()
 	return &Ip6AndPort{
@@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
 		Port: uint32(port),
 	}
 }
-func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
-	ip := ipp.Ip
-	return udp.NewAddr(
-		net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
-		uint16(ipp.Port),
-	)
-}
-
-func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
-	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
-}
 
 func (lh *LightHouse) startQueryWorker() {
 	if lh.amLighthouse {
@@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() {
 	}()
 }
 
-func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
+func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) {
 	if lh.IsLighthouseIP(ip) {
 		return
 	}
@@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() {
 	var v6 []*Ip6AndPort
 
 	for _, e := range lh.GetAdvertiseAddrs() {
-		if ip := e.ip.To4(); ip != nil {
-			v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port)))
+		if e.Addr().Is4() {
+			v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port()))
 		} else {
-			v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port)))
+			v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port()))
 		}
 	}
 
 	lal := lh.GetLocalAllowList()
-	for _, e := range *localIps(lh.l, lal) {
-		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
+	for _, e := range localIps(lh.l, lal) {
+		if lh.myVpnNet.Contains(e) {
 			continue
 		}
 
 		// Only add IPs that aren't my VPN/tun IP
-		if ip := e.To4(); ip != nil {
-			v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort))
+		if e.Is4() {
+			v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort)))
 		} else {
-			v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort))
+			v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort)))
 		}
 	}
 
 	var relays []uint32
 	for _, r := range lh.GetRelaysForMe() {
-		relays = append(relays, (uint32)(r))
+		//TODO: IPV6-WORK both relays and vpnip need ipv6 support
+		b := r.As4()
+		relays = append(relays, binary.BigEndian.Uint32(b[:]))
 	}
 
+	//TODO: IPV6-WORK both relays and vpnip need ipv6 support
+	b := lh.myVpnNet.Addr().As4()
+
 	m := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
-			VpnIp:       uint32(lh.myVpnIp),
+			VpnIp:       binary.BigEndian.Uint32(b[:]),
 			Ip4AndPorts: v4,
 			Ip6AndPorts: v6,
 			RelayVpnIp:  relays,
@@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 }
 
 func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
-	return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
+	return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) {
 		lhh.HandleRequest(rAddr, vpnIp, p, f)
 	}
 }
 
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
+func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) {
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
 	if err != nil {
@@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) {
 	// Exit if we don't answer queries
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
@@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
 
 	//TODO: we can DRY this further
 	reqVpnIp := n.Details.VpnIp
+
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+	queryVpnIp := netip.AddrFrom4(b)
+
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
-	found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
+	found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		n.Details.VpnIp = reqVpnIp
@@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
 	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
-		n.Details.VpnIp = uint32(vpnIp)
-
+		//TODO: IPV6-WORK
+		b = vpnIp.As4()
+		n.Details.VpnIp = binary.BigEndian.Uint32(b[:])
 		lhh.coalesceAnswers(c, n)
 
 		return n.MarshalTo(lhh.pb)
@@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
 	}
 
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
-	w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
+
+	//TODO: IPV6-WORK
+	binary.BigEndian.PutUint32(b[:], reqVpnIp)
+	sendTo := netip.AddrFrom4(b)
+	w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
 func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
 	}
 
 	if c.relay != nil {
-		n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...)
+		//TODO: IPV6-WORK
+		relays := make([]uint32, len(c.relay.relay))
+		b := [4]byte{}
+		for i, _ := range relays {
+			b = c.relay.relay[i].As4()
+			relays[i] = binary.BigEndian.Uint32(b[:])
+		}
+		n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...)
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
+func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 	}
 
 	lhh.lh.Lock()
-	am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+	certVpnIp := netip.AddrFrom4(b)
+	am := lhh.lh.unlockedGetRemoteList(certVpnIp)
 	am.Lock()
 	lhh.lh.Unlock()
 
-	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+	//TODO: IPV6-WORK
 	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
 	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
-	am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+
+	//TODO: IPV6-WORK
+	relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+	for i, _ := range n.Details.RelayVpnIp {
+		binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+		relays[i] = netip.AddrFrom4(b)
+	}
+	am.unlockedSetRelay(vpnIp, certVpnIp, relays)
 	am.Unlock()
 
 	// Non-blocking attempt to trigger, skip if it would block
 	select {
-	case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
+	case lhh.lh.handshakeTrigger <- certVpnIp:
 	default:
 	}
 }
 
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	}
 
 	//Simple check that the host sent this not someone else
-	if n.Details.VpnIp != uint32(vpnIp) {
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+	detailsVpnIp := netip.AddrFrom4(b)
+	if detailsVpnIp != vpnIp {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+			lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
 		}
 		return
 	}
@@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.Lock()
 	lhh.lh.Unlock()
 
-	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
-	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
-	am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+	am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+
+	//TODO: IPV6-WORK
+	relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+	for i, _ := range n.Details.RelayVpnIp {
+		binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+		relays[i] = netip.AddrFrom4(b)
+	}
+	am.unlockedSetRelay(vpnIp, detailsVpnIp, relays)
 	am.Unlock()
 
 	n = lhh.resetMeta()
 	n.Type = NebulaMeta_HostUpdateNotificationAck
-	n.Details.VpnIp = uint32(vpnIp)
+
+	//TODO: IPV6-WORK
+	vpnIpB := vpnIp.As4()
+	n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:])
 	ln, err := n.MarshalTo(lhh.pb)
 
 	if err != nil {
@@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 	}
 
 	empty := []byte{0}
-	punch := func(vpnPeer *udp.Addr) {
-		if vpnPeer == nil {
+	punch := func(vpnPeer netip.AddrPort) {
+		if !vpnPeer.IsValid() {
 			return
 		}
 
@@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
 
 		if lhh.l.Level >= logrus.DebugLevel {
 			//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
-			lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
+			//TODO: IPV6-WORK, make this debug line not suck
+			b := [4]byte{}
+			binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+			lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b))
 		}
 	}
 
 	for _, a := range n.Details.Ip4AndPorts {
-		punch(NewUDPAddrFromLH4(a))
+		punch(AddrPortFromIp4AndPort(a))
 	}
 
 	for _, a := range n.Details.Ip6AndPorts {
-		punch(NewUDPAddrFromLH6(a))
+		punch(AddrPortFromIp6AndPort(a))
 	}
 
 	// This sends a nebula test packet to the host trying to contact us. In the case
 	// of a double nat or other difficult scenario, this may help establish
 	// a tunnel.
 	if lhh.lh.punchy.GetRespond() {
-		queryVpnIp := iputil.VpnIp(n.Details.VpnIp)
+		//TODO: IPV6-WORK
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+		queryVpnIp := netip.AddrFrom4(b)
 		go func() {
 			time.Sleep(lhh.lh.punchy.GetRespondDelay())
 			if lhh.l.Level >= logrus.DebugLevel {
@@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
 		}()
 	}
 }
-
-// ipMaskContains checks if testIp is contained by ip after applying a cidr.
-// zeros is 32 - bits from net.IPMask.Size()
-func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
-	return (testIp^ip)>>zeros == 0
-}

+ 85 - 102
lighthouse_test.go

@@ -2,15 +2,14 @@ package nebula
 
 import (
 	"context"
+	"encoding/binary"
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
-	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"gopkg.in/yaml.v2"
 )
@@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) {
 	var m Ip4AndPort
 	err := m.Unmarshal(b)
 	assert.NoError(t, err)
-	assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
+	ip := netip.MustParseAddr("10.1.1.1")
+	bp := ip.As4()
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
 }
 
 func TestNewLhQuery(t *testing.T) {
-	myIp := net.ParseIP("192.1.1.1")
-	myIpint := iputil.Ip2VpnIp(myIp)
+	myIp, err := netip.ParseAddr("192.1.1.1")
+	assert.NoError(t, err)
 
 	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIpint)
+	a := NewLhQueryByInt(myIp)
 
 	// The result should be a nebulameta protobuf
 	assert.IsType(t, &NebulaMeta{}, a)
@@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) {
 
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
@@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) {
 
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	lh1 := "10.128.0.2"
 
 	c := config.NewC(l)
@@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	lh.ifce = &mockEncWriter{}
 
 	// The first one routine is kicked off by main.go currently, lets make sure that one dies
-	c.ReloadConfigString("lighthouse:\n  interval: 5")
+	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 5"))
 	assert.Equal(t, int64(5), lh.interval.Load())
 
 	// Subsequent calls are killed off by the LightHouse.Reload function
-	c.ReloadConfigString("lighthouse:\n  interval: 10")
+	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 10"))
 	assert.Equal(t, int64(10), lh.interval.Load())
 
 	// If this completes then nothing is stealing our reload routine
-	c.ReloadConfigString("lighthouse:\n  interval: 11")
+	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 11"))
 	assert.Equal(t, int64(11), lh.interval.Load())
 }
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
 
 	c := config.NewC(l)
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
@@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		b.Fatal()
 	}
 
-	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
-	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = NewRemoteList(nil)
-	lh.addrMap[3].unlockedSetV4(
-		3,
-		3,
+	hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
+	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
+
+	vpnIp3 := netip.MustParseAddr("0.0.0.3")
+	lh.addrMap[vpnIp3] = NewRemoteList(nil)
+	lh.addrMap[vpnIp3].unlockedSetV4(
+		vpnIp3,
+		vpnIp3,
 		[]*Ip4AndPort{
-			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
-			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+			NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
+			NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 
-	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
-	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = NewRemoteList(nil)
-	lh.addrMap[2].unlockedSetV4(
-		3,
-		3,
+	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
+	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
+	vpnIp2 := netip.MustParseAddr("0.0.0.3")
+	lh.addrMap[vpnIp2] = NewRemoteList(nil)
+	lh.addrMap[vpnIp2].unlockedSetV4(
+		vpnIp3,
+		vpnIp3,
 		[]*Ip4AndPort{
-			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
-			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+			NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
+			NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 
 	mw := &mockEncWriter{}
@@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, 2, p, mw)
+			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
 		}
 	})
 	b.Run("found", func(b *testing.B) {
@@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		assert.NoError(b, err)
 
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, 2, p, mw)
+			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
 		}
 	})
 }
@@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func TestLighthouse_Memory(t *testing.T) {
 	l := test.NewLogger()
 
-	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
-	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
-	myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
-	myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
-	myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
-	myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
-	myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
-	myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
-	myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
-	myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
-	myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
-	myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
-	myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
-
-	theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
-	theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
-	theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
-	theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
-	theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
-	theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+	myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
+	myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
+	myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
+	myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
+	myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
+	myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
+	myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
+	myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
+	myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
+	myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
+	myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
+	myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
+	myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+	theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
+	theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
+	theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
+	theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
+	theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
+	theirVpnIp := netip.MustParseAddr("10.128.0.3")
 
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 
 	// Test that my first update responds with just that
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
 
 	// Ensure we don't accumulate addresses
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
 
 	// Grow it back to 2
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
 	// Update a different host and ask about it
-	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
@@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	newLHHostUpdate(
 		myUdpAddr0,
 		myVpnIp,
-		[]*udp.Addr{
+		[]netip.AddrPort{
 			myUdpAddr1,
 			myUdpAddr2,
 			myUdpAddr3,
@@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) {
 	)
 
 	// Make sure we won't add ips in our vpn network
-	bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
-	bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
-	good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
+	bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
+	bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
+	good := netip.MustParseAddrPort("1.128.0.99:4242")
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 }
@@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
 	assert.NoError(t, err)
 
 	nc := map[interface{}]interface{}{
@@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) {
 	assert.NoError(t, err)
 }
 
-func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
+	//TODO: IPV6-WORK
+	bip := queryVpnIp.As4()
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
-			VpnIp: uint32(queryVpnIp),
+			VpnIp: binary.BigEndian.Uint32(bip[:]),
 		},
 	}
 
@@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
 	return w.lastReply
 }
 
-func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
+	//TODO: IPV6-WORK
+	bip := vpnIp.As4()
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
-			VpnIp:       uint32(vpnIp),
+			VpnIp:       binary.BigEndian.Uint32(bip[:]),
 			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
 		},
 	}
 
 	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
+		req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
 	}
 
 	b, err := req.Marshal()
@@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
 //	)
 //}
 
-func Test_ipMaskContains(t *testing.T) {
-	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
-	assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-}
-
 type testLhReply struct {
 	nebType    header.MessageType
 	nebSubType header.MessageSubType
-	vpnIp      iputil.VpnIp
+	vpnIp      netip.Addr
 	msg        *NebulaMeta
 }
 
@@ -414,7 +416,7 @@ type testEncWriter struct {
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 }
-func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
+func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
 }
 
 func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
@@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 }
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 }
 
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
-	if !assert.Len(t, have, len(want)) {
-		return
-	}
-
-	for k, w := range want {
-		if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
-			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
-		}
-	}
-}
-
-// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
+func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 		return
 	}
 
 	for k, w := range want {
-		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
-			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+		//TODO: IPV6-WORK
+		h := AddrPortFromIp4AndPort(have[k])
+		if !(h == w) {
+			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 	}
 }
-
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
-	addrs := make([]*udp.Addr, len(ips))
-	for k, v := range ips {
-		addrs[k] = NewUDPAddrFromLH4(v)
-	}
-	return addrs
-}

+ 22 - 8
main.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"net"
+	"net/netip"
 	"time"
 
 	"github.com/sirupsen/logrus"
@@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
-	// TODO: make sure mask is 4 bytes
-	tunCidr := certificate.Details.Ips[0]
+	ones, _ := certificate.Details.Ips[0].Mask.Size()
+	addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+	if !ok {
+		err = util.NewContextualError(
+			"Invalid ip address in certificate",
+			m{"vpnIp": certificate.Details.Ips[0].IP},
+			nil,
+		)
+		return nil, err
+	}
+	tunCidr := netip.PrefixFrom(addr, ones)
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
@@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	if !configTest {
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
-		var listenHost *net.IPAddr
+		var listenHost netip.Addr
 		if rawListenHost == "[::]" {
 			// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
-			listenHost = &net.IPAddr{IP: net.IPv6zero}
+			listenHost = netip.IPv6Unspecified()
 
 		} else {
-			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
+			ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
 			if err != nil {
 				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
 			}
+			if len(ips) == 0 {
+				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
+			}
+			listenHost = ips[0].Unmap()
 		}
 
 		for i := 0; i < routines; i++ {
-			l.Infof("listening %q %d", listenHost.IP, port)
-			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
+			l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
+			udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
@@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 				if err != nil {
 					return nil, util.NewContextualError("Failed to get listening port", nil, err)
 				}
-				port = int(uPort.Port)
+				port = int(uPort.Port())
 			}
 		}
 	}

+ 48 - 47
outside.go

@@ -4,6 +4,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"fmt"
+	"net/netip"
 	"time"
 
 	"github.com/flynn/noise"
@@ -11,7 +12,6 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 	"google.golang.org/protobuf/proto"
@@ -21,9 +21,10 @@ const (
 	minFwPacketLen = 4
 )
 
+// TODO: IPV6-WORK this can likely be removed now
 func readOutsidePackets(f *Interface) udp.EncReader {
 	return func(
-		addr *udp.Addr,
+		addr netip.AddrPort,
 		out []byte,
 		packet []byte,
 		header *header.H,
@@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader {
 	}
 }
 
-func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
 		// TODO: best if we return this and let caller log
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
-			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
+			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
 		}
 		return
 	}
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
-	if addr != nil {
-		if ip4 := addr.IP.To4(); ip4 != nil {
-			if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
-				if f.l.Level >= logrus.DebugLevel {
-					f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
-				}
-				return
+	if ip.IsValid() {
+		if f.myVpnNet.Contains(ip.Addr()) {
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
+			return
 		}
 	}
 
@@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 	switch h.Type {
 	case header.Message:
 		// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 		}
 
@@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			// Successfully validated the thing. Get rid of the Relay header.
 			signedPayload = signedPayload[header.Len:]
 			// Pull the Roaming parts up here, and return in all call paths.
-			f.handleHostRoaming(hostinfo, addr)
+			f.handleHostRoaming(hostinfo, ip)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			f.connectionManager.In(hostinfo.localIndexId)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			case TerminalType:
 				// If I am the target of this relay, process the unwrapped packet
 				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+				f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
@@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 	case header.LightHouse:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
 
@@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			return
 		}
 
-		lhf(addr, hostinfo.vpnIp, d)
+		lhf(ip, hostinfo.vpnIp, d)
 
 		// Fallthrough to the bottom to record incoming traffic
 
 	case header.Test:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
 
@@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 		if h.Subtype == header.TestRequest {
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
-			f.handleHostRoaming(hostinfo, addr)
+			f.handleHostRoaming(hostinfo, ip)
 			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
 		}
 
@@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handshakeManager.HandleIncoming(addr, via, packet, h)
+		f.handshakeManager.HandleIncoming(ip, via, packet, h)
 		return
 
 	case header.RecvError:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handleRecvError(addr, h)
+		f.handleRecvError(ip, h)
 		return
 
 	case header.CloseTunnel:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", addr).
+		hostinfo.logger(f.l).WithField("udpAddr", ip).
 			Info("Close tunnel received, tearing down.")
 
 		f.closeTunnel(hostinfo)
 		return
 
 	case header.Control:
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				Error("Failed to decrypt Control packet")
 			return
@@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
+		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
 		return
 	}
 
-	f.handleHostRoaming(hostinfo, addr)
+	f.handleHostRoaming(hostinfo, ip)
 
 	f.connectionManager.In(hostinfo.localIndexId)
 }
@@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
-	if addr != nil && !hostinfo.remote.Equals(addr) {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
-			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
+	if ip.IsValid() && hostinfo.remote != ip {
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
-		if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+		if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(addr)
+		hostinfo.SetRemote(ip)
 	}
 
 }
 
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
 	// If connectionstate exists and the replay protector allows, process packet
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
 	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
-		if addr != nil {
+		if addr.IsValid() {
 			f.maybeSendRecvError(addr, h.RemoteIndex)
 			return false
 		} else {
@@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 
 	// Firewall packets are locally oriented
 	if incoming {
-		fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
-		fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
+		//TODO: IPV6-WORK
+		fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
+		fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
@@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 	} else {
-		fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
-		fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
+		//TODO: IPV6-WORK
+		fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
+		fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.LocalPort = 0
@@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	return true
 }
 
-func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {
-	if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) {
+func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
+	if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) {
 		f.sendRecvError(endpoint, index)
 	}
 }
 
-func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
+func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
 	//TODO: this should be a signed message so we can trust that we should drop the index
@@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
 	}
 }
 
-func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
+func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", h.RemoteIndex).
 			WithField("udpAddr", addr).
@@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 		return
 	}
 
-	if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
+	if hostinfo.remote.IsValid() && hostinfo.remote != addr {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return
 	}

+ 5 - 5
outside_test.go

@@ -2,10 +2,10 @@ package nebula
 
 import (
 	"net"
+	"net/netip"
 	"testing"
 
 	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
 )
@@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) {
 
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
-	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
+	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
+	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.LocalPort, uint16(4))
 
@@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) {
 
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
-	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
+	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
+	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.LocalPort, uint16(5))
 }

+ 3 - 5
overlay/device.go

@@ -2,16 +2,14 @@ package overlay
 
 import (
 	"io"
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 
 type Device interface {
 	io.ReadWriteCloser
 	Activate() error
-	Cidr() *net.IPNet
+	Cidr() netip.Prefix
 	Name() string
-	RouteFor(iputil.VpnIp) iputil.VpnIp
+	RouteFor(netip.Addr) netip.Addr
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }

+ 19 - 25
overlay/route.go

@@ -1,34 +1,30 @@
 package overlay
 
 import (
-	"bytes"
 	"fmt"
 	"math"
 	"net"
+	"net/netip"
 	"runtime"
 	"strconv"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type Route struct {
 	MTU     int
 	Metric  int
-	Cidr    *net.IPNet
-	Via     *iputil.VpnIp
+	Cidr    netip.Prefix
+	Via     netip.Addr
 	Install bool
 }
 
 // Equal determines if a route that could be installed in the system route table is equal to another
 // Via is ignored since that is only consumed within nebula itself
 func (r Route) Equal(t Route) bool {
-	if !r.Cidr.IP.Equal(t.Cidr.IP) {
-		return false
-	}
-	if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
+	if r.Cidr != t.Cidr {
 		return false
 	}
 	if r.Metric != t.Metric {
@@ -51,21 +47,21 @@ func (r Route) String() string {
 	return s
 }
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
-	routeTree := cidr.NewTree4[iputil.VpnIp]()
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
+	routeTree := new(bart.Table[netip.Addr])
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
 		}
 
-		if r.Via != nil {
-			routeTree.AddCIDR(r.Cidr, *r.Via)
+		if r.Via.IsValid() {
+			routeTree.Insert(r.Cidr, r.Via)
 		}
 	}
 	return routeTree, nil
 }
 
-func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	var err error
 
 	r := c.Get("tun.routes")
@@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			MTU:     mtu,
 		}
 
-		_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+		r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
 		if err != nil {
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 		}
 
-		if !ipWithin(network, r.Cidr) {
+		if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
 			return nil, fmt.Errorf(
 				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
 				i+1,
@@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 	return routes, nil
 }
 
-func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	var err error
 
 	r := c.Get("tun.unsafe_routes")
@@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
 		}
 
-		nVia := net.ParseIP(via)
-		if nVia == nil {
-			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
+		viaVpnIp, err := netip.ParseAddr(via)
+		if err != nil {
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
 		}
 
 		rRoute, ok := m["route"]
@@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
 		}
 
-		viaVpnIp := iputil.Ip2VpnIp(nVia)
-
 		install := true
 		rInstall, ok := m["install"]
 		if ok {
@@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 		}
 
 		r := Route{
-			Via:     &viaVpnIp,
+			Via:     viaVpnIp,
 			MTU:     mtu,
 			Metric:  metric,
 			Install: install,
 		}
 
-		_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+		r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
 		if err != nil {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 		}
 
-		if ipWithin(network, r.Cidr) {
+		if network.Contains(r.Cidr.Addr()) {
 			return nil, fmt.Errorf(
 				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
 				i+1,

+ 27 - 16
overlay/route_test.go

@@ -2,11 +2,10 @@ package overlay
 
 import (
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 )
@@ -14,7 +13,8 @@ import (
 func Test_parseRoutes(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	assert.NoError(t, err)
 
 	// test no routes config
 	routes, err := parseRoutes(c, n)
@@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
+	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
@@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) {
 func Test_parseUnsafeRoutes(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	assert.NoError(t, err)
 
 	// test no routes config
 	routes, err := parseUnsafeRoutes(c, n)
@@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
+	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
@@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
+	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
@@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 func Test_makeRouteTree(t *testing.T) {
 	l := test.NewLogger()
 	c := config.NewC(l)
-	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	assert.NoError(t, err)
 
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
@@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) {
 	routeTree, err := makeRouteTree(l, routes, true)
 	assert.NoError(t, err)
 
-	ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
-	ok, r := routeTree.MostSpecificContains(ip)
+	ip, err := netip.ParseAddr("1.0.0.2")
+	assert.NoError(t, err)
+	r, ok := routeTree.Lookup(ip)
 	assert.True(t, ok)
-	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
 
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
-	ok, r = routeTree.MostSpecificContains(ip)
+	nip, err := netip.ParseAddr("192.168.0.1")
+	assert.NoError(t, err)
+	assert.Equal(t, nip, r)
+
+	ip, err = netip.ParseAddr("1.0.0.1")
+	assert.NoError(t, err)
+	r, ok = routeTree.Lookup(ip)
 	assert.True(t, ok)
-	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
 
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
-	ok, r = routeTree.MostSpecificContains(ip)
+	nip, err = netip.ParseAddr("192.168.0.2")
+	assert.NoError(t, err)
+	assert.Equal(t, nip, r)
+
+	ip, err = netip.ParseAddr("1.1.0.1")
+	assert.NoError(t, err)
+	r, ok = routeTree.Lookup(ip)
 	assert.False(t, ok)
 }

+ 5 - 5
overlay/tun.go

@@ -1,7 +1,7 @@
 package overlay
 
 import (
-	"net"
+	"net/netip"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -11,9 +11,9 @@ import (
 const DefaultMTU = 1300
 
 // TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
 	switch {
 	case c.GetBool("tun.disabled", false):
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
 }
 
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+	return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
 		return newTunFromFd(c, l, *fd, tunCidr)
 	}
 }
 
-func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 		return false, nil, nil
 	}

+ 9 - 10
overlay/tun_android.go

@@ -6,27 +6,26 @@ package overlay
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"sync/atomic"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 )
 
 type tun struct {
 	io.ReadWriteCloser
 	fd        int
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
 	return t, nil
 }
 
-func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
@@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 41 - 18
overlay/tun_darwin.go

@@ -8,15 +8,15 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"net/netip"
 	"os"
 	"sync/atomic"
 	"syscall"
 	"unsafe"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	netroute "golang.org/x/net/route"
 	"golang.org/x/sys/unix"
@@ -25,10 +25,10 @@ import (
 type tun struct {
 	io.ReadWriteCloser
 	Device     string
-	cidr       *net.IPNet
+	cidr       netip.Prefix
 	DefaultMTU int
 	Routes     atomic.Pointer[[]Route]
-	routeTree  atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree  atomic.Pointer[bart.Table[netip.Addr]]
 	linkAddr   *netroute.LinkAddr
 	l          *logrus.Logger
 
@@ -73,7 +73,7 @@ type ifreqMTU struct {
 	pad  [8]byte
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	name := c.GetString("tun.dev", "")
 	ifIndex := -1
 	if name != "" && name != "utun" {
@@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 
@@ -188,8 +188,13 @@ func (t *tun) Activate() error {
 
 	var addr, mask [4]byte
 
-	copy(addr[:], t.cidr.IP.To4())
-	copy(mask[:], t.cidr.Mask)
+	if !t.cidr.Addr().Is4() {
+		//TODO: IPV6-WORK
+		panic("need ipv6")
+	}
+
+	addr = t.cidr.Addr().As4()
+	copy(mask[:], prefixToMask(t.cidr))
 
 	s, err := unix.Socket(
 		unix.AF_INET,
@@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	ok, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, ok := t.routeTree.Load().Lookup(ip)
 	if ok {
 		return r
 	}
-
-	return 0
+	return netip.Addr{}
 }
 
 // Get the LinkAddr for the interface of the given name
@@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error {
 	maskAddr := &netroute.Inet4Addr{}
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		copy(routeAddr.IP[:], r.Cidr.IP.To4())
-		copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+		if !r.Cidr.Addr().Is4() {
+			//TODO: implement ipv6
+			panic("Cant handle ipv6 routes yet")
+		}
+
+		routeAddr.IP = r.Cidr.Addr().As4()
+		//TODO: we could avoid the copy
+		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
 
 		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
 		if err != nil {
@@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		copy(routeAddr.IP[:], r.Cidr.IP.To4())
-		copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+		if r.Cidr.Addr().Is6() {
+			//TODO: implement ipv6
+			panic("Cant handle ipv6 routes yet")
+		}
+
+		routeAddr.IP = r.Cidr.Addr().As4()
+		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
 
 		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
 		if err != nil {
@@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) {
 	return n - 4, err
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 
@@ -547,3 +562,11 @@ func (t *tun) Name() string {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
+
+func prefixToMask(prefix netip.Prefix) []byte {
+	pLen := 128
+	if prefix.Addr().Is4() {
+		pLen = 32
+	}
+	return net.CIDRMask(prefix.Bits(), pLen)
+}

+ 6 - 6
overlay/tun_disabled.go

@@ -3,7 +3,7 @@ package overlay
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"strings"
 
 	"github.com/rcrowley/go-metrics"
@@ -13,7 +13,7 @@ import (
 
 type disabledTun struct {
 	read chan []byte
-	cidr *net.IPNet
+	cidr netip.Prefix
 
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
@@ -21,7 +21,7 @@ type disabledTun struct {
 	l  *logrus.Logger
 }
 
-func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
 		cidr: cidr,
 		read: make(chan []byte, queueLen),
@@ -43,11 +43,11 @@ func (*disabledTun) Activate() error {
 	return nil
 }
 
-func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
-	return 0
+func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
+	return netip.Addr{}
 }
 
-func (t *disabledTun) Cidr() *net.IPNet {
+func (t *disabledTun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 11 - 12
overlay/tun_freebsd.go

@@ -9,7 +9,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
-	"net"
+	"net/netip"
 	"os"
 	"os/exec"
 	"strconv"
@@ -17,10 +17,9 @@ import (
 	"syscall"
 	"unsafe"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -48,10 +47,10 @@ type ifreqDestroy struct {
 
 type tun struct {
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -79,11 +78,11 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
 	var file *os.File
 	var err error
@@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error
 func (t *tun) Activate() error {
 	var err error
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 
@@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}

+ 9 - 10
overlay/tun_ios.go

@@ -7,32 +7,31 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"sync"
 	"sync/atomic"
 	"syscall"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 )
 
 type tun struct {
 	io.ReadWriteCloser
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 }
 
-func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	t := &tun{
 		cidr:            cidr,
@@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
@@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error {
 	return tr.f.Close()
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 56 - 35
overlay/tun_linux.go

@@ -4,19 +4,18 @@
 package overlay
 
 import (
-	"bytes"
 	"fmt"
 	"io"
 	"net"
+	"net/netip"
 	"os"
 	"strings"
 	"sync/atomic"
 	"unsafe"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
@@ -26,7 +25,7 @@ type tun struct {
 	io.ReadWriteCloser
 	fd          int
 	Device      string
-	cidr        *net.IPNet
+	cidr        netip.Prefix
 	MaxMTU      int
 	DefaultMTU  int
 	TXQueueLen  int
@@ -34,7 +33,7 @@ type tun struct {
 	ioctlFd     uintptr
 
 	Routes          atomic.Pointer[[]Route]
-	routeTree       atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree       atomic.Pointer[bart.Table[netip.Addr]]
 	routeChan       chan struct{}
 	useSystemRoutes bool
 
@@ -65,7 +64,7 @@ type ifreqQLEN struct {
 	pad   [8]byte
 }
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 	t, err := newTunGeneric(c, l, file, cidr)
@@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
 	return t, nil
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t
 	return t, nil
 }
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
 	t := &tun{
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
@@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return file, nil
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
@@ -275,8 +274,10 @@ func (t *tun) Activate() error {
 
 	var addr, mask [4]byte
 
-	copy(addr[:], t.cidr.IP.To4())
-	copy(mask[:], t.cidr.Mask)
+	//TODO: IPV6-WORK
+	addr = t.cidr.Addr().As4()
+	tmask := net.CIDRMask(t.cidr.Bits(), 32)
+	copy(mask[:], tmask)
 
 	s, err := unix.Socket(
 		unix.AF_INET,
@@ -364,14 +365,19 @@ func (t *tun) setMTU() {
 
 func (t *tun) setDefaultRoute() error {
 	// Default route
-	dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
+
+	dr := &net.IPNet{
+		IP:   t.cidr.Masked().Addr().AsSlice(),
+		Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+	}
+
 	nr := netlink.Route{
 		LinkIndex: t.deviceIndex,
 		Dst:       dr,
 		MTU:       t.DefaultMTU,
 		AdvMSS:    t.advMSS(Route{}),
 		Scope:     unix.RT_SCOPE_LINK,
-		Src:       t.cidr.IP,
+		Src:       net.IP(t.cidr.Addr().AsSlice()),
 		Protocol:  unix.RTPROT_KERNEL,
 		Table:     unix.RT_TABLE_MAIN,
 		Type:      unix.RTN_UNICAST,
@@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 		}
 
+		dr := &net.IPNet{
+			IP:   r.Cidr.Masked().Addr().AsSlice(),
+			Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+		}
+
 		nr := netlink.Route{
 			LinkIndex: t.deviceIndex,
-			Dst:       r.Cidr,
+			Dst:       dr,
 			MTU:       r.MTU,
 			AdvMSS:    t.advMSS(r),
 			Scope:     unix.RT_SCOPE_LINK,
@@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) {
 			continue
 		}
 
+		dr := &net.IPNet{
+			IP:   r.Cidr.Masked().Addr().AsSlice(),
+			Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+		}
+
 		nr := netlink.Route{
 			LinkIndex: t.deviceIndex,
-			Dst:       r.Cidr,
+			Dst:       dr,
 			MTU:       r.MTU,
 			AdvMSS:    t.advMSS(r),
 			Scope:     unix.RT_SCOPE_LINK,
@@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) {
 	}
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 
@@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	if !t.cidr.Contains(r.Gw) {
+	//TODO: IPV6-WORK what if not ok?
+	gwAddr, ok := netip.AddrFromSlice(r.Gw)
+	if !ok {
+		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
+		return
+	}
+
+	gwAddr = gwAddr.Unmap()
+	if !t.cidr.Contains(gwAddr) {
 		// Gateway isn't in our overlay network, ignore
 		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
 		return
@@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	newTree := cidr.NewTree4[iputil.VpnIp]()
-	if r.Type == unix.RTM_NEWROUTE {
-		for _, oldR := range t.routeTree.Load().List() {
-			newTree.AddCIDR(oldR.CIDR, oldR.Value)
-		}
+	dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
+	if !ok {
+		t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
+		return
+	}
+
+	ones, _ := r.Dst.Mask.Size()
+	dst := netip.PrefixFrom(dstAddr, ones)
+
+	newTree := t.routeTree.Load().Clone()
 
+	if r.Type == unix.RTM_NEWROUTE {
 		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
-		newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
+		newTree.Insert(dst, gwAddr)
 
 	} else {
-		gw := iputil.Ip2VpnIp(r.Gw)
-		for _, oldR := range t.routeTree.Load().List() {
-			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
-				// This is the record to delete
-				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
-				continue
-			}
-
-			newTree.AddCIDR(oldR.CIDR, oldR.Value)
-		}
+		newTree.Delete(dst)
+		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 	}
-
 	t.routeTree.Store(newTree)
 }
 

+ 14 - 15
overlay/tun_netbsd.go

@@ -6,7 +6,7 @@ package overlay
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os/exec"
 	"regexp"
@@ -15,10 +15,9 @@ import (
 	"syscall"
 	"unsafe"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 )
 
@@ -29,10 +28,10 @@ type ifreqDestroy struct {
 
 type tun struct {
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -59,13 +58,13 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
 	var file *os.File
 	var err error
@@ -109,13 +108,13 @@ func (t *tun) Activate() error {
 	var err error
 
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 
@@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

+ 14 - 15
overlay/tun_openbsd.go

@@ -6,7 +6,7 @@ package overlay
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os/exec"
 	"regexp"
@@ -14,19 +14,18 @@ import (
 	"sync/atomic"
 	"syscall"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 )
 
 type tun struct {
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -43,13 +42,13 @@ func (t *tun) Close() error {
 	return nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 }
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) Activate() error {
 	var err error
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -139,7 +138,7 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -149,20 +148,20 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 }
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 9 - 10
overlay/tun_tester.go

@@ -6,21 +6,20 @@ package overlay
 import (
 	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"sync/atomic"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type TestTun struct {
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	Routes    []Route
-	routeTree *cidr.Tree4[iputil.VpnIp]
+	routeTree *bart.Table[netip.Addr]
 	l         *logrus.Logger
 
 	closed    atomic.Bool
@@ -28,7 +27,7 @@ type TestTun struct {
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
 	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
 	if err != nil {
 		return nil, err
@@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e
 	}, nil
 }
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 
@@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.MostSpecificContains(ip)
+func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Lookup(ip)
 	return r
 }
 
@@ -96,7 +95,7 @@ func (t *TestTun) Activate() error {
 	return nil
 }
 
-func (t *TestTun) Cidr() *net.IPNet {
+func (t *TestTun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 11 - 11
overlay/tun_water_windows.go

@@ -4,30 +4,30 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"net/netip"
 	"os/exec"
 	"strconv"
 	"sync/atomic"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/songgao/water"
 )
 
 type waterTun struct {
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	f         *net.Interface
 	*water.Interface
 }
 
-func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
+func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	t := &waterTun{
 		cidr: cidr,
@@ -70,8 +70,8 @@ func (t *waterTun) Activate() error {
 		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
 		fmt.Sprintf("name=%s", t.Device),
 		"source=static",
-		fmt.Sprintf("addr=%s", t.cidr.IP),
-		fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
+		fmt.Sprintf("addr=%s", t.cidr.Addr()),
+		fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
 		"gateway=none",
 	).Run()
 	if err != nil {
@@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error {
 	// Path routes
 	routes := *t.Routes.Load()
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
@@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) {
 	}
 }
 
-func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
-func (t *waterTun) Cidr() *net.IPNet {
+func (t *waterTun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 3 - 3
overlay/tun_windows.go

@@ -5,7 +5,7 @@ package overlay
 
 import (
 	"fmt"
-	"net"
+	"net/netip"
 	"os"
 	"path/filepath"
 	"runtime"
@@ -15,11 +15,11 @@ import (
 	"github.com/slackhq/nebula/config"
 )
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
 	useWintun := true
 	if err := checkWinTunExists(); err != nil {
 		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

+ 12 - 38
overlay/tun_wintun_windows.go

@@ -4,15 +4,13 @@ import (
 	"crypto"
 	"fmt"
 	"io"
-	"net"
 	"net/netip"
 	"sync/atomic"
 	"unsafe"
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/wintun"
 	"golang.org/x/sys/windows"
@@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
 
 type winTun struct {
 	Device    string
-	cidr      *net.IPNet
-	prefix    netip.Prefix
+	cidr      netip.Prefix
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 
 	tun *wintun.NativeTun
@@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
 	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 }
 
-func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
+func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	guid, err := generateGUIDByDeviceName(deviceName)
 	if err != nil {
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
 	}
 
-	prefix, err := iputil.ToNetIpPrefix(*cidr)
-	if err != nil {
-		return nil, err
-	}
-
 	t := &winTun{
 		Device: deviceName,
 		cidr:   cidr,
-		prefix: prefix,
 		MTU:    c.GetInt("tun.mtu", DefaultMTU),
 		l:      l,
 	}
@@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error {
 func (t *winTun) Activate() error {
 	luid := winipcfg.LUID(t.tun.LUID())
 
-	err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
+	err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
 	if err != nil {
 		return fmt.Errorf("failed to set address: %w", err)
 	}
@@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error {
 	foundDefault4 := false
 
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			continue
 		}
 
-		prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
-		if err != nil {
-			retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
-			if logErrors {
-				retErr.Log(t.l)
-				continue
-			} else {
-				return retErr
-			}
-		}
-
 		// Add our unsafe route
-		err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
+		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
 		if err != nil {
 			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
 			if logErrors {
@@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
 		}
 
 		if !foundDefault4 {
-			if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
+			if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
 				foundDefault4 = true
 			}
 		}
@@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
 			continue
 		}
 
-		prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
-		if err != nil {
-			t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
-			continue
-		}
-
-		err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
+		err := luid.DeleteRoute(r.Cidr, r.Via)
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
@@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error {
 	return nil
 }
 
-func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 }
 
-func (t *winTun) Cidr() *net.IPNet {
+func (t *winTun) Cidr() netip.Prefix {
 	return t.cidr
 }
 

+ 7 - 8
overlay/user.go

@@ -2,18 +2,17 @@ package overlay
 
 import (
 	"io"
-	"net"
+	"net/netip"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 
-func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
 	return NewUserDevice(tunCidr)
 }
 
-func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
 	// these pipes guarantee each write/read will match 1:1
 	or, ow := io.Pipe()
 	ir, iw := io.Pipe()
@@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
 }
 
 type UserDevice struct {
-	tunCidr *net.IPNet
+	tunCidr netip.Prefix
 
 	outboundReader *io.PipeReader
 	outboundWriter *io.PipeWriter
@@ -39,9 +38,9 @@ type UserDevice struct {
 func (d *UserDevice) Activate() error {
 	return nil
 }
-func (d *UserDevice) Cidr() *net.IPNet                      { return d.tunCidr }
-func (d *UserDevice) Name() string                          { return "faketun0" }
-func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) Cidr() netip.Prefix                { return d.tunCidr }
+func (d *UserDevice) Name() string                      { return "faketun0" }
+func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return d, nil
 }

+ 2 - 0
pki.go

@@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
 	}
 
 	if !initial {
+		//TODO: include check for mask equality as well
+
 		// did IP in cert change? if so, don't set
 		currentCert := p.cs.Load().Certificate
 		oldIPs := currentCert.Details.Ips

+ 52 - 31
relay_manager.go

@@ -2,14 +2,15 @@ package nebula
 
 import (
 	"context"
+	"encoding/binary"
 	"errors"
 	"fmt"
+	"net/netip"
 	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 )
 
 type relayManager struct {
@@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) {
 
 // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
 // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
-func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) {
+func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
 	hm.Lock()
 	defer hm.Unlock()
 	for i := 0; i < 32; i++ {
@@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter
 
 func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
 	rm.l.WithFields(logrus.Fields{
-		"relayFrom":           iputil.VpnIp(m.RelayFromIp),
-		"relayTo":             iputil.VpnIp(m.RelayToIp),
+		"relayFrom":           m.RelayFromIp,
+		"relayTo":             m.RelayToIp,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
 		"vpnIp":               h.vpnIp}).
 		Info("handleCreateRelayResponse")
-	target := iputil.VpnIp(m.RelayToIp)
+	target := m.RelayToIp
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+	targetAddr := netip.AddrFrom4(b)
 
 	relay, err := rm.EstablishRelay(h, m)
 	if err != nil {
@@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
 		return
 	}
-	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
+	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
 	if !ok {
 		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
 		return
 	}
 	if peerRelay.State == PeerRequested {
+		//TODO: IPV6-WORK
+		b = peerHostInfo.vpnIp.As4()
 		peerRelay.State = Established
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: peerRelay.LocalIndex,
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
-			RelayFromIp:         uint32(peerHostInfo.vpnIp),
+			RelayFromIp:         binary.BigEndian.Uint32(b[:]),
 			RelayToIp:           uint32(target),
 		}
 		msg, err := resp.Marshal()
@@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-				"relayTo":             iputil.VpnIp(resp.RelayToIp),
+				"relayFrom":           resp.RelayFromIp,
+				"relayTo":             resp.RelayToIp,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"vpnIp":               peerHostInfo.vpnIp}).
@@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 }
 
 func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
+	from := netip.AddrFrom4(b)
 
-	from := iputil.VpnIp(m.RelayFromIp)
-	target := iputil.VpnIp(m.RelayToIp)
+	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+	target := netip.AddrFrom4(b)
 
 	logMsg := rm.l.WithFields(logrus.Fields{
 		"relayFrom":           from,
@@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 	logMsg.Info("handleCreateRelayRequest")
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// an issue migrating relays over to newly re-handshaked host info objects.
-	if from == f.myVpnIp {
-		logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
+	if from == f.myVpnNet.Addr() {
+		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		return
 	}
 	// Is the target of the relay me?
-	if target == f.myVpnIp {
+	if target == f.myVpnNet.Addr() {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		if ok {
 			switch existingRelay.State {
@@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			return
 		}
 
+		//TODO: IPV6-WORK
+		fromB := from.As4()
+		targetB := target.As4()
+
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: relay.LocalIndex,
 			InitiatorRelayIndex: relay.RemoteIndex,
-			RelayFromIp:         uint32(from),
-			RelayToIp:           uint32(target),
+			RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
+			RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 		}
 		msg, err := resp.Marshal()
 		if err != nil {
@@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-				"relayTo":             iputil.VpnIp(resp.RelayToIp),
+				//TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
+				"relayFrom":           from,
+				"relayTo":             target,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"vpnIp":               h.vpnIp}).
@@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			f.Handshake(target)
 			return
 		}
-		if peer.remote == nil {
+		if !peer.remote.IsValid() {
 			// Only create relays to peers for whom I have a direct connection
 			return
 		}
@@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			sendCreateRequest = true
 		}
 		if sendCreateRequest {
+			//TODO: IPV6-WORK
+			fromB := h.vpnIp.As4()
+			targetB := target.As4()
+
 			// Send a CreateRelayRequest to the peer.
 			req := NebulaControl{
 				Type:                NebulaControl_CreateRelayRequest,
 				InitiatorRelayIndex: index,
-				RelayFromIp:         uint32(h.vpnIp),
-				RelayToIp:           uint32(target),
+				RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
+				RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 			}
 			msg, err := req.Marshal()
 			if err != nil {
@@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			} else {
 				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				rm.l.WithFields(logrus.Fields{
-					"relayFrom":           iputil.VpnIp(req.RelayFromIp),
-					"relayTo":             iputil.VpnIp(req.RelayToIp),
+					//TODO: IPV6-WORK another lazy used to use the req object
+					"relayFrom":           h.vpnIp,
+					"relayTo":             target,
 					"initiatorRelayIndex": req.InitiatorRelayIndex,
 					"responderRelayIndex": req.ResponderRelayIndex,
 					"vpnIp":               target}).
@@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 						"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
 					return
 				}
+				//TODO: IPV6-WORK
+				fromB := h.vpnIp.As4()
+				targetB := target.As4()
 				resp := NebulaControl{
 					Type:                NebulaControl_CreateRelayResponse,
 					ResponderRelayIndex: relay.LocalIndex,
 					InitiatorRelayIndex: relay.RemoteIndex,
-					RelayFromIp:         uint32(h.vpnIp),
-					RelayToIp:           uint32(target),
+					RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
+					RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 				}
 				msg, err := resp.Marshal()
 				if err != nil {
@@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 				} else {
 					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					rm.l.WithFields(logrus.Fields{
-						"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-						"relayTo":             iputil.VpnIp(resp.RelayToIp),
+						//TODO: IPV6-WORK more lazy, used to use resp object
+						"relayFrom":           h.vpnIp,
+						"relayTo":             target,
 						"initiatorRelayIndex": resp.InitiatorRelayIndex,
 						"responderRelayIndex": resp.ResponderRelayIndex,
 						"vpnIp":               h.vpnIp}).
@@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		}
 	}
 }
-
-func (rm *relayManager) RemoveRelay(localIdx uint32) {
-	rm.hostmap.RemoveRelay(localIdx)
-}

+ 73 - 93
remote_list.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"bytes"
 	"context"
 	"net"
 	"net/netip"
@@ -12,16 +11,14 @@ import (
 	"time"
 
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 
 // forEachFunc is used to benefit folks that want to do work inside the lock
-type forEachFunc func(addr *udp.Addr, preferred bool)
+type forEachFunc func(addr netip.AddrPort, preferred bool)
 
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
+type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
 
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // The string key is the owners vpnIp
@@ -30,9 +27,9 @@ type CacheMap map[string]*Cache
 // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
 // We don't reason about ipv4 vs ipv6 here
 type Cache struct {
-	Learned  []*udp.Addr `json:"learned,omitempty"`
-	Reported []*udp.Addr `json:"reported,omitempty"`
-	Relay    []*net.IP   `json:"relay"`
+	Learned  []netip.AddrPort `json:"learned,omitempty"`
+	Reported []netip.AddrPort `json:"reported,omitempty"`
+	Relay    []netip.Addr     `json:"relay"`
 }
 
 //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -46,7 +43,7 @@ type cache struct {
 }
 
 type cacheRelay struct {
-	relay []uint32
+	relay []netip.Addr
 }
 
 // cacheV4 stores learned and reported ipv4 records under cache
@@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 						continue
 					}
 					for _, a := range addrs {
-						netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+						netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{}
 					}
 				}
 				origSet := r.ips.Load()
@@ -193,22 +190,22 @@ type RemoteList struct {
 	sync.RWMutex
 
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
-	addrs []*udp.Addr
+	addrs []netip.AddrPort
 
 	// A set of relay addresses. VpnIp addresses that the remote identified as relays.
-	relays []*iputil.VpnIp
+	relays []netip.Addr
 
 	// These are maps to store v4 and v6 addresses per lighthouse
 	// Map key is the vpnIp of the person that told us about this the cached entries underneath.
 	// For learned addresses, this is the vpnIp that sent the packet
-	cache map[iputil.VpnIp]*cache
+	cache map[netip.Addr]*cache
 
 	hr        *hostnamesResults
 	shouldAdd func(netip.Addr) bool
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
-	badRemotes []*udp.Addr
+	badRemotes []netip.AddrPort
 
 	// A flag that the cache may have changed and addrs needs to be rebuilt
 	shouldRebuild bool
@@ -217,9 +214,9 @@ type RemoteList struct {
 // NewRemoteList creates a new empty RemoteList
 func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
 	return &RemoteList{
-		addrs:     make([]*udp.Addr, 0),
-		relays:    make([]*iputil.VpnIp, 0),
-		cache:     make(map[iputil.VpnIp]*cache),
+		addrs:     make([]netip.AddrPort, 0),
+		relays:    make([]netip.Addr, 0),
+		cache:     make(map[netip.Addr]*cache),
 		shouldAdd: shouldAdd,
 	}
 }
@@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
 
 // Len locks and reports the size of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
+func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
 	r.Rebuild(preferredRanges)
 	r.RLock()
 	defer r.RUnlock()
@@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
 
 // ForEach locks and will call the forEachFunc for every deduplicated address in the list
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
+func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) {
 	r.Rebuild(preferredRanges)
 	r.RLock()
 	for _, v := range r.addrs {
-		forEach(v, isPreferred(v.IP, preferredRanges))
+		forEach(v, isPreferred(v.Addr(), preferredRanges))
 	}
 	r.RUnlock()
 }
 
 // CopyAddrs locks and makes a deep copy of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
+func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort {
 	if r == nil {
 		return nil
 	}
@@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
 
 	r.RLock()
 	defer r.RUnlock()
-	c := make([]*udp.Addr, len(r.addrs))
+	c := make([]netip.AddrPort, len(r.addrs))
 	for i, v := range r.addrs {
-		c[i] = v.Copy()
+		c[i] = v
 	}
 	return c
 }
@@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
 // TODO: this needs to support the allow list list
-func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
+func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
 	r.Lock()
 	defer r.Unlock()
-	if v4 := addr.IP.To4(); v4 != nil {
-		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
+	if remote.Addr().Is4() {
+		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
 	} else {
-		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
+		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
 	}
 }
 
@@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
 		c := cm[vpnIp]
 		if c == nil {
 			c = &Cache{
-				Learned:  make([]*udp.Addr, 0),
-				Reported: make([]*udp.Addr, 0),
-				Relay:    make([]*net.IP, 0),
+				Learned:  make([]netip.AddrPort, 0),
+				Reported: make([]netip.AddrPort, 0),
+				Relay:    make([]netip.Addr, 0),
 			}
 			cm[vpnIp] = c
 		}
@@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap {
 
 		if mc.v4 != nil {
 			if mc.v4.learned != nil {
-				c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
+				c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
 			}
 
 			for _, a := range mc.v4.reported {
-				c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
+				c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
 			}
 		}
 
 		if mc.v6 != nil {
 			if mc.v6.learned != nil {
-				c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
+				c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
 			}
 
 			for _, a := range mc.v6.reported {
-				c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
+				c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
 			}
 		}
 
 		if mc.relay != nil {
 			for _, a := range mc.relay.relay {
-				nip := iputil.VpnIp(a).ToIP()
-				c.Relay = append(c.Relay, &nip)
+				c.Relay = append(c.Relay, a)
 			}
 		}
 	}
@@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
 }
 
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad *udp.Addr) {
-	if bad == nil {
+func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
+	if !bad.IsValid() {
 		// relays can have nil udp Addrs
 		return
 	}
@@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) {
 	}
 
 	// We copy here because we are taking something else's memory and we can't trust everything
-	r.badRemotes = append(r.badRemotes, bad.Copy())
+	r.badRemotes = append(r.badRemotes, bad)
 
 	// Mark the next interaction must recollect/dedupe
 	r.shouldRebuild = true
 }
 
 // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
-func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
+func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
 	r.RLock()
 	defer r.RUnlock()
 
-	c := make([]*udp.Addr, len(r.badRemotes))
+	c := make([]netip.AddrPort, len(r.badRemotes))
 	for i, v := range r.badRemotes {
-		c[i] = v.Copy()
+		c[i] = v
 	}
 	return c
 }
@@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() {
 
 // Rebuild locks and generates the deduplicated address list only if there is work to be done
 // There is generally no reason to call this directly but it is safe to do so
-func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
+func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
 	r.Lock()
 	defer r.Unlock()
 
@@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
 }
 
 // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
-func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
+func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
 	for _, v := range r.badRemotes {
-		if v.Equals(remote) {
+		if v == remote {
 			return true
 		}
 	}
@@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
 
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 }
 
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
@@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
 	}
 }
 
-func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) {
+func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeRelay(ownerVpnIp)
 
@@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI
 
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
@@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort)
 
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 }
 
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
@@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
 
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
@@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort)
 	}
 }
 
-func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay {
+func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay {
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 		am = &cache{}
@@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay
 
 // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 {
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 		am = &cache{}
@@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
 
 // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 {
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 		am = &cache{}
@@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() {
 	for _, c := range r.cache {
 		if c.v4 != nil {
 			if c.v4.learned != nil {
-				u := NewUDPAddrFromLH4(c.v4.learned)
+				u := AddrPortFromIp4AndPort(c.v4.learned)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
 			}
 
 			for _, v := range c.v4.reported {
-				u := NewUDPAddrFromLH4(v)
+				u := AddrPortFromIp4AndPort(v)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
@@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() {
 
 		if c.v6 != nil {
 			if c.v6.learned != nil {
-				u := NewUDPAddrFromLH6(c.v6.learned)
+				u := AddrPortFromIp6AndPort(c.v6.learned)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
 			}
 
 			for _, v := range c.v6.reported {
-				u := NewUDPAddrFromLH6(v)
+				u := AddrPortFromIp6AndPort(v)
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 				}
@@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() {
 
 		if c.relay != nil {
 			for _, v := range c.relay.relay {
-				ip := iputil.VpnIp(v)
-				relays = append(relays, &ip)
+				relays = append(relays, v)
 			}
 		}
 	}
@@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() {
 	dnsAddrs := r.hr.GetIPs()
 	for _, addr := range dnsAddrs {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
-			v6 := addr.Addr().As16()
-			addrs = append(addrs, &udp.Addr{
-				IP:   v6[:],
-				Port: addr.Port(),
-			})
+			addrs = append(addrs, addr)
 		}
 	}
 
@@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() {
 }
 
 // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
-func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
 	n := len(r.addrs)
 	if n < 2 {
 		return
@@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 		b := r.addrs[j]
 		// Preferred addresses first
 
-		aPref := isPreferred(a.IP, preferredRanges)
-		bPref := isPreferred(b.IP, preferredRanges)
+		aPref := isPreferred(a.Addr(), preferredRanges)
+		bPref := isPreferred(b.Addr(), preferredRanges)
 		switch {
 		case aPref && !bPref:
 			// If i is preferred and j is not, i is less than j
@@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 		}
 
 		// ipv6 addresses 2nd
-		a4 := a.IP.To4()
-		b4 := b.IP.To4()
+		a4 := a.Addr().Is4()
+		b4 := b.Addr().Is4()
 		switch {
-		case a4 == nil && b4 != nil:
+		case a4 == false && b4 == true:
 			// If i is v6 and j is v4, i is less than j
 			return true
 
-		case a4 != nil && b4 == nil:
+		case a4 == true && b4 == false:
 			// If j is v6 and i is v4, i is not less than j
 			return false
 
-		case a4 != nil && b4 != nil:
-			// Special case for ipv4, a4 and b4 are not nil
-			aPrivate := isPrivateIP(a4)
-			bPrivate := isPrivateIP(b4)
+		case a4 == true && b4 == true:
+			// i and j are both ipv4
+			aPrivate := a.Addr().IsPrivate()
+			bPrivate := b.Addr().IsPrivate()
 			switch {
 			case !aPrivate && bPrivate:
 				// If i is a public ip (not private) and j is a private ip, i is less then j
@@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 		}
 
 		// lexical order of ips 3rd
-		c := bytes.Compare(a.IP, b.IP)
+		c := a.Addr().Compare(b.Addr())
 		if c == 0 {
 			// Ips are the same, Lexical order of ports 4th
-			return a.Port < b.Port
+			return a.Port() < b.Port()
 		}
 
 		// Ip wasn't the same
@@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 	// Deduplicate
 	a, b := 0, 1
 	for b < n {
-		if !r.addrs[a].Equals(r.addrs[b]) {
+		if r.addrs[a] != r.addrs[b] {
 			a++
 			if a != b {
 				r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
@@ -693,7 +684,7 @@ func minInt(a, b int) int {
 }
 
 // isPreferred returns true of the ip is contained in the preferredRanges list
-func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
 	//TODO: this would be better in a CIDR6Tree
 	for _, p := range preferredRanges {
 		if p.Contains(ip) {
@@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
 	}
 	return false
 }
-
-var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
-var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
-var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
-
-// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
-func isPrivateIP(ip net.IP) bool {
-	//TODO: another great cidrtree option
-	//TODO: Private for ipv6 or just let it ride?
-	return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
-}

+ 98 - 89
remote_list_test.go

@@ -1,47 +1,47 @@
 package nebula
 
 import (
-	"net"
+	"encoding/binary"
+	"net/netip"
 	"testing"
 
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestRemoteList_Rebuild(t *testing.T) {
 	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip4AndPort{
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
+			newIp4AndPortFromString("172.17.0.182:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
+			newIp4AndPortFromString("172.18.0.1:10101"), // this is duped
+			newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe
+			newIp4AndPortFromString("172.19.0.1:10101"),
+			newIp4AndPortFromString("172.31.0.1:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
+			newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
-		1,
-		1,
+		netip.MustParseAddr("0.0.0.1"),
+		netip.MustParseAddr("0.0.0.1"),
 		[]*Ip6AndPort{
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
-			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+			newIp6AndPortFromString("[1::1]:1"), // this is duped
+			newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
+			newIp6AndPortFromString("[1:100::1]:1"),
+			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
+			newIp6AndPortFromString("[1::1]:2"), // this is a dupe
 		},
-		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *Ip6AndPort) bool { return true },
 	)
 
-	rl.Rebuild([]*net.IPNet{})
+	rl.Rebuild([]netip.Prefix{})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 
 	// ipv6 first, sorted lexically within
@@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
 
 	// Now ensure we can hoist ipv4 up
-	_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
-	assert.NoError(t, err)
-	rl.Rebuild([]*net.IPNet{ipNet})
+	rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 
 	// ipv4 first, public then private, lexically within them
@@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
 
 	// Ensure we can hoist a specific ipv4 range over anything else
-	_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
-	assert.NoError(t, err)
-	rl.Rebuild([]*net.IPNet{ipNet})
+	rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 
 	// Preferred ipv4 first
@@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) {
 func BenchmarkFullRebuild(b *testing.B) {
 	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip4AndPort{
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+			newIp4AndPortFromString("70.199.182.92:1475"),
+			newIp4AndPortFromString("172.17.0.182:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),
+			newIp4AndPortFromString("172.18.0.1:10101"),
+			newIp4AndPortFromString("172.19.0.1:10101"),
+			newIp4AndPortFromString("172.31.0.1:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip6AndPort{
-			NewIp6AndPort(net.ParseIP("1::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
-			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			newIp6AndPortFromString("[1::1]:1"),
+			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+			newIp6AndPortFromString("[1:100::1]:1"),
+			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
-		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *Ip6AndPort) bool { return true },
 	)
 
 	b.Run("no preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{})
+			rl.Rebuild([]netip.Prefix{})
 		}
 	})
 
-	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
-	assert.NoError(b, err)
+	ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
 	b.Run("1 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{ipNet})
+			rl.Rebuild([]netip.Prefix{ipNet1})
 		}
 	})
 
-	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
-	assert.NoError(b, err)
+	ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
 	b.Run("2 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+			rl.Rebuild([]netip.Prefix{ipNet2})
 		}
 	})
 
-	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
-	assert.NoError(b, err)
+	ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
 	b.Run("3 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+			rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
 		}
 	})
 }
@@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) {
 func BenchmarkSortRebuild(b *testing.B) {
 	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip4AndPort{
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+			newIp4AndPortFromString("70.199.182.92:1475"),
+			newIp4AndPortFromString("172.17.0.182:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),
+			newIp4AndPortFromString("172.18.0.1:10101"),
+			newIp4AndPortFromString("172.19.0.1:10101"),
+			newIp4AndPortFromString("172.31.0.1:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 
 	rl.unlockedSetV6(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip6AndPort{
-			NewIp6AndPort(net.ParseIP("1::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
-			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			newIp6AndPortFromString("[1::1]:1"),
+			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+			newIp6AndPortFromString("[1:100::1]:1"),
+			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
-		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *Ip6AndPort) bool { return true },
 	)
 
 	b.Run("no preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{})
+			rl.Rebuild([]netip.Prefix{})
 		}
 	})
 
-	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
-	rl.Rebuild([]*net.IPNet{ipNet})
+	ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
+	rl.Rebuild([]netip.Prefix{ipNet1})
 
-	assert.NoError(b, err)
 	b.Run("1 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
-			rl.Rebuild([]*net.IPNet{ipNet})
+			rl.Rebuild([]netip.Prefix{ipNet1})
 		}
 	})
 
-	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
-	rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+	ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
+	rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
 
-	assert.NoError(b, err)
 	b.Run("2 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+			rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
 		}
 	})
 
-	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
-	rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+	ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
+	rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
 
-	assert.NoError(b, err)
 	b.Run("3 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+			rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
 		}
 	})
 }
+
+func newIp4AndPortFromString(s string) *Ip4AndPort {
+	a := netip.MustParseAddrPort(s)
+	v4Addr := a.Addr().As4()
+	return &Ip4AndPort{
+		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+		Port: uint32(a.Port()),
+	}
+}
+
+func newIp6AndPortFromString(s string) *Ip6AndPort {
+	a := netip.MustParseAddrPort(s)
+	v6Addr := a.Addr().As16()
+	return &Ip6AndPort{
+		Hi:   binary.BigEndian.Uint64(v6Addr[:8]),
+		Lo:   binary.BigEndian.Uint64(v6Addr[8:]),
+		Port: uint32(a.Port()),
+	}
+}

+ 1 - 1
service/service.go

@@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) {
 
 	ipNet := device.Cidr()
 	pa := tcpip.ProtocolAddress{
-		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
+		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
 		Protocol:          ipv4.ProtocolNumber,
 	}
 	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

+ 6 - 10
service/service_test.go

@@ -4,7 +4,7 @@ import (
 	"bytes"
 	"context"
 	"errors"
-	"net"
+	"net/netip"
 	"testing"
 	"time"
 
@@ -18,12 +18,8 @@ import (
 
 type m map[string]interface{}
 
-func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
-
-	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
-	copy(vpnIpNet.IP, udpIp)
-
-	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
+	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
 	caB, err := caCrt.MarshalToPEM()
 	if err != nil {
 		panic(err)
@@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string,
 }
 
 func TestService(t *testing.T) {
-	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
 		"static_host_map": m{},
 		"lighthouse": m{
 			"am_lighthouse": true,
@@ -94,7 +90,7 @@ func TestService(t *testing.T) {
 			"port": 4243,
 		},
 	})
-	b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+	b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
 		"static_host_map": m{
 			"10.0.0.1": []string{"localhost:4243"},
 		},

+ 29 - 36
ssh.go

@@ -7,6 +7,7 @@ import (
 	"flag"
 	"fmt"
 	"net"
+	"net/netip"
 	"os"
 	"reflect"
 	"runtime"
@@ -18,9 +19,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
-	"github.com/slackhq/nebula/udp"
 )
 
 type sshListHostMapFlags struct {
@@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 	}
 
 	sort.Slice(hm, func(i, j int) bool {
-		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
+		return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
 	})
 
 	if fs.Json || fs.Pretty {
@@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine("No vpn ip was provided")
 	}
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
@@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine("No vpn ip was provided")
 	}
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
@@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("No vpn ip was provided")
 	}
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
@@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
 
-	var addr *udp.Addr
+	var addr netip.AddrPort
 	if flags.Address != "" {
-		addr = udp.NewAddrFromString(flags.Address)
-		if addr == nil {
+		addr, err = netip.ParseAddrPort(flags.Address)
+		if err != nil {
 			return w.WriteLine("Address could not be parsed")
 		}
 	}
 
 	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
-	if addr != nil {
+	if addr.IsValid() {
 		hostInfo.SetRemote(addr)
 	}
 
@@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("No address was provided")
 	}
 
-	addr := udp.NewAddrFromString(flags.Address)
-	if addr == nil {
+	addr, err := netip.ParseAddrPort(flags.Address)
+	if err != nil {
 		return w.WriteLine("Address could not be parsed")
 	}
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
@@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 
 	cert := ifce.pki.GetCertState().Certificate
 	if len(a) > 0 {
-		parsedIp := net.ParseIP(a[0])
-		if parsedIp == nil {
+		vpnIp, err := netip.ParseAddr(a[0])
+		if err != nil {
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 
-		vpnIp := iputil.Ip2VpnIp(parsedIp)
-		if vpnIp == 0 {
+		if !vpnIp.IsValid() {
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 
@@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		Error          error
 		Type           string
 		State          string
-		PeerIp         iputil.VpnIp
+		PeerIp         netip.Addr
 		LocalIndex     uint32
 		RemoteIndex    uint32
-		RelayedThrough []iputil.VpnIp
+		RelayedThrough []netip.Addr
 	}
 
 	type RelayOutput struct {
-		NebulaIp    iputil.VpnIp
+		NebulaIp    netip.Addr
 		RelayForIps []RelayFor
 	}
 
@@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine("No vpn ip was provided")
 	}
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 

+ 5 - 7
test/tun.go

@@ -3,23 +3,21 @@ package test
 import (
 	"errors"
 	"io"
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 
 type NoopTun struct{}
 
-func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
-	return 0
+func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
+	return netip.Addr{}
 }
 
 func (NoopTun) Activate() error {
 	return nil
 }
 
-func (NoopTun) Cidr() *net.IPNet {
-	return nil
+func (NoopTun) Cidr() netip.Prefix {
+	return netip.Prefix{}
 }
 
 func (NoopTun) Name() string {

+ 5 - 4
timeout_test.go

@@ -1,6 +1,7 @@
 package nebula
 
 import (
+	"net/netip"
 	"testing"
 	"time"
 
@@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
 	assert.Equal(t, 0, tw.current)
 
 	fps := []firewall.Packet{
-		{LocalIP: 1},
-		{LocalIP: 2},
-		{LocalIP: 3},
-		{LocalIP: 4},
+		{LocalIP: netip.MustParseAddr("0.0.0.1")},
+		{LocalIP: netip.MustParseAddr("0.0.0.2")},
+		{LocalIP: netip.MustParseAddr("0.0.0.3")},
+		{LocalIP: netip.MustParseAddr("0.0.0.4")},
 	}
 
 	tw.Add(fps[0], time.Second*1)

+ 8 - 6
udp/conn.go

@@ -1,6 +1,8 @@
 package udp
 
 import (
+	"net/netip"
+
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
@@ -9,7 +11,7 @@ import (
 const MTU = 9001
 
 type EncReader func(
-	addr *Addr,
+	addr netip.AddrPort,
 	out []byte,
 	packet []byte,
 	header *header.H,
@@ -22,9 +24,9 @@ type EncReader func(
 
 type Conn interface {
 	Rebind() error
-	LocalAddr() (*Addr, error)
+	LocalAddr() (netip.AddrPort, error)
 	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
-	WriteTo(b []byte, addr *Addr) error
+	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
 	Close() error
 }
@@ -34,13 +36,13 @@ type NoopConn struct{}
 func (NoopConn) Rebind() error {
 	return nil
 }
-func (NoopConn) LocalAddr() (*Addr, error) {
-	return nil, nil
+func (NoopConn) LocalAddr() (netip.AddrPort, error) {
+	return netip.AddrPort{}, nil
 }
 func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
 	return
 }
-func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
+func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 }
 func (NoopConn) ReloadConfig(_ *config.C) {

+ 3 - 2
udp/temp.go

@@ -1,9 +1,10 @@
 package udp
 
 import (
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 
 //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
 
-type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
+// TODO: IPV6-WORK this can likely be removed now
+type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

+ 0 - 100
udp/udp_all.go

@@ -1,100 +0,0 @@
-package udp
-
-import (
-	"encoding/json"
-	"fmt"
-	"net"
-	"strconv"
-)
-
-type m map[string]interface{}
-
-type Addr struct {
-	IP   net.IP
-	Port uint16
-}
-
-func NewAddr(ip net.IP, port uint16) *Addr {
-	addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
-	copy(addr.IP, ip.To16())
-	return &addr
-}
-
-func NewAddrFromString(s string) *Addr {
-	ip, port, err := ParseIPAndPort(s)
-	//TODO: handle err
-	_ = err
-	return &Addr{IP: ip.To16(), Port: port}
-}
-
-func (ua *Addr) Equals(t *Addr) bool {
-	if t == nil || ua == nil {
-		return t == nil && ua == nil
-	}
-	return ua.IP.Equal(t.IP) && ua.Port == t.Port
-}
-
-func (ua *Addr) String() string {
-	if ua == nil {
-		return "<nil>"
-	}
-
-	return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
-}
-
-func (ua *Addr) MarshalJSON() ([]byte, error) {
-	if ua == nil {
-		return nil, nil
-	}
-
-	return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
-}
-
-func (ua *Addr) Copy() *Addr {
-	if ua == nil {
-		return nil
-	}
-
-	nu := Addr{
-		Port: ua.Port,
-		IP:   make(net.IP, len(ua.IP)),
-	}
-
-	copy(nu.IP, ua.IP)
-	return &nu
-}
-
-type AddrSlice []*Addr
-
-func (a AddrSlice) Equal(b AddrSlice) bool {
-	if len(a) != len(b) {
-		return false
-	}
-
-	for i := range a {
-		if !a[i].Equals(b[i]) {
-			return false
-		}
-	}
-
-	return true
-}
-
-func ParseIPAndPort(s string) (net.IP, uint16, error) {
-	rIp, sPort, err := net.SplitHostPort(s)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	addr, err := net.ResolveIPAddr("ip", rIp)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	iPort, err := strconv.Atoi(sPort)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	return addr.IP, uint16(iPort), nil
-}

+ 2 - 1
udp/udp_android.go

@@ -6,13 +6,14 @@ package udp
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 

+ 2 - 1
udp/udp_bsd.go

@@ -9,13 +9,14 @@ package udp
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 

+ 2 - 1
udp/udp_darwin.go

@@ -8,13 +8,14 @@ package udp
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 

+ 23 - 14
udp/udp_generic.go

@@ -11,6 +11,7 @@ import (
 	"context"
 	"fmt"
 	"net"
+	"net/netip"
 
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
@@ -25,7 +26,7 @@ type GenericConn struct {
 
 var _ Conn = &GenericConn{}
 
-func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
 	if err != nil {
@@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 }
 
-func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
-	_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
+func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
+	_, err := u.UDPConn.WriteToUDPAddrPort(b, addr)
 	return err
 }
 
-func (u *GenericConn) LocalAddr() (*Addr, error) {
+func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 
 	switch v := a.(type) {
 	case *net.UDPAddr:
-		addr := &Addr{IP: make([]byte, len(v.IP))}
-		copy(addr.IP, v.IP)
-		addr.Port = uint16(v.Port)
-		return addr, nil
+		addr, ok := netip.AddrFromSlice(v.IP)
+		if !ok {
+			return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+		}
+		return netip.AddrPortFrom(addr, uint16(v.Port)), nil
 
 	default:
-		return nil, fmt.Errorf("LocalAddr returned: %#v", a)
+		return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
 	}
 }
 
@@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
 	buffer := make([]byte, MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
-	udpAddr := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 
 	for {
 		// Just read one packet at a time
-		n, rua, err := u.ReadFromUDP(buffer)
+		n, rua, err := u.ReadFromUDPAddrPort(buffer)
 		if err != nil {
 			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
 			return
 		}
 
-		udpAddr.IP = rua.IP
-		udpAddr.Port = uint16(rua.Port)
-		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(
+			netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
+			plaintext[:0],
+			buffer[:n],
+			h,
+			fwPacket,
+			lhf,
+			nb,
+			q,
+			cache.Get(u.l),
+		)
 	}
 }

+ 43 - 32
udp/udp_linux.go

@@ -7,6 +7,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"net"
+	"net/netip"
 	"syscall"
 	"unsafe"
 
@@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
 	return ip, false
 }
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
-	ipV4, isV4 := maybeIPV4(ip)
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	af := unix.AF_INET6
-	if isV4 {
+	if ip.Is4() {
 		af = unix.AF_INET
 	}
 	syscall.ForkLock.RLock()
@@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
 
 	//TODO: support multiple listening IPs (for limiting ipv6)
 	var sa unix.Sockaddr
-	if isV4 {
+	if ip.Is4() {
 		sa4 := &unix.SockaddrInet4{Port: port}
-		copy(sa4.Addr[:], ipV4)
+		sa4.Addr = ip.As4()
 		sa = sa4
 	} else {
 		sa6 := &unix.SockaddrInet6{Port: port}
-		copy(sa6.Addr[:], ip.To16())
+		sa6.Addr = ip.As16()
 		sa = sa6
 	}
 	if err = unix.Bind(fd, sa); err != nil {
@@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//l.Println(v, err)
 
-	return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
+	return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
 }
 
 func (u *StdConn) Rebind() error {
@@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 }
 
-func (u *StdConn) LocalAddr() (*Addr, error) {
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := unix.Getsockname(u.sysFd)
 	if err != nil {
-		return nil, err
+		return netip.AddrPort{}, err
 	}
 
-	addr := &Addr{}
 	switch sa := sa.(type) {
 	case *unix.SockaddrInet4:
-		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
-		addr.Port = uint16(sa.Port)
+		return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
+
 	case *unix.SockaddrInet6:
-		addr.IP = sa.Addr[0:]
-		addr.Port = uint16(sa.Port)
-	}
+		return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
 
-	return addr, nil
+	default:
+		return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
+	}
 }
 
 func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
 	plaintext := make([]byte, MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
-	udpAddr := &Addr{}
+	var ip netip.Addr
 	nb := make([]byte, 12, 12)
 
 	//TODO: should we track this?
@@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 		//metric.Update(int64(n))
 		for i := 0; i < n; i++ {
 			if u.isV4 {
-				udpAddr.IP = names[i][4:8]
+				ip, _ = netip.AddrFromSlice(names[i][4:8])
+				//TODO: IPV6-WORK what is not ok?
 			} else {
-				udpAddr.IP = names[i][8:24]
+				ip, _ = netip.AddrFromSlice(names[i][8:24])
+				//TODO: IPV6-WORK what is not ok?
 			}
-			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+			r(
+				netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
+				plaintext[:0],
+				buffers[i][:msgs[i].Len],
+				h,
+				fwPacket,
+				lhf,
+				nb,
+				q,
+				cache.Get(u.l),
+			)
 		}
 	}
 }
@@ -197,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
 	}
 }
 
-func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
+func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
 	if u.isV4 {
-		return u.writeTo4(b, addr)
+		return u.writeTo4(b, ip)
 	}
-	return u.writeTo6(b, addr)
+	return u.writeTo6(b, ip)
 }
 
-func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
+func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
+	rsa.Addr = ip.Addr().As16()
+	port := ip.Port()
 	// Little Endian -> Network Endian
-	rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
-	copy(rsa.Addr[:], addr.IP.To16())
+	rsa.Port = (port >> 8) | ((port & 0xff) << 8)
 
 	for {
 		_, _, err := unix.Syscall6(
@@ -232,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
 	}
 }
 
-func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
-	addrV4, isAddrV4 := maybeIPV4(addr.IP)
-	if !isAddrV4 {
+func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
+	if !ip.Addr().Is4() {
 		return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
 	}
 
 	var rsa unix.RawSockaddrInet4
 	rsa.Family = unix.AF_INET
+	rsa.Addr = ip.Addr().As4()
+	port := ip.Port()
 	// Little Endian -> Network Endian
-	rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
-	copy(rsa.Addr[:], addrV4)
+	rsa.Port = (port >> 8) | ((port & 0xff) << 8)
 
 	for {
 		_, _, err := unix.Syscall6(

+ 2 - 1
udp/udp_netbsd.go

@@ -8,13 +8,14 @@ package udp
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 )
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 

+ 22 - 21
udp/udp_rio_windows.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"net/netip"
 	"sync"
 	"sync/atomic"
 	"syscall"
@@ -61,16 +62,14 @@ type RIOConn struct {
 	results [packetsPerRing]winrio.Result
 }
 
-func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
 	if !winrio.Initialize() {
 		return nil, errors.New("could not initialize winrio")
 	}
 
 	u := &RIOConn{l: l}
 
-	addr := [16]byte{}
-	copy(addr[:], ip.To16())
-	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+	err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port})
 	if err != nil {
 		return nil, fmt.Errorf("bind: %w", err)
 	}
@@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 	buffer := make([]byte, MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
-	udpAddr := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 
 	for {
@@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 			return
 		}
 
-		udpAddr.IP = rua.Addr[:]
-		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
-		p[0] = byte(rua.Port >> 8)
-		p[1] = byte(rua.Port)
-		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(
+			netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
+			plaintext[:0],
+			buffer[:n],
+			h,
+			fwPacket,
+			lhf,
+			nb,
+			q,
+			cache.Get(u.l),
+		)
 	}
 }
 
@@ -231,7 +235,7 @@ retry:
 	return n, ep, nil
 }
 
-func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
 	if !u.isOpen.Load() {
 		return net.ErrClosed
 	}
@@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
 
 	packet := u.tx.Push()
 	packet.addr.Family = windows.AF_INET6
-	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
-	p[0] = byte(addr.Port >> 8)
-	p[1] = byte(addr.Port)
-	copy(packet.addr.Addr[:], addr.IP.To16())
+	packet.addr.Addr = ip.Addr().As16()
+	port := ip.Port()
+	packet.addr.Port = (port >> 8) | ((port & 0xff) << 8)
 	copy(packet.data[:], buf)
 
 	dataBuffer := &winrio.Buffer{
@@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
 	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
 }
 
-func (u *RIOConn) LocalAddr() (*Addr, error) {
+func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := windows.Getsockname(u.sock)
 	if err != nil {
-		return nil, err
+		return netip.AddrPort{}, err
 	}
 
 	v6 := sa.(*windows.SockaddrInet6)
-	return &Addr{
-		IP:   v6.Addr[:],
-		Port: uint16(v6.Port),
-	}, nil
+	return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil
+
 }
 
 func (u *RIOConn) Rebind() error {

+ 17 - 32
udp/udp_tester.go

@@ -4,9 +4,8 @@
 package udp
 
 import (
-	"fmt"
 	"io"
-	"net"
+	"net/netip"
 	"sync/atomic"
 
 	"github.com/sirupsen/logrus"
@@ -16,30 +15,24 @@ import (
 )
 
 type Packet struct {
-	ToIp     net.IP
-	ToPort   uint16
-	FromIp   net.IP
-	FromPort uint16
-	Data     []byte
+	To   netip.AddrPort
+	From netip.AddrPort
+	Data []byte
 }
 
 func (u *Packet) Copy() *Packet {
 	n := &Packet{
-		ToIp:     make(net.IP, len(u.ToIp)),
-		ToPort:   u.ToPort,
-		FromIp:   make(net.IP, len(u.FromIp)),
-		FromPort: u.FromPort,
-		Data:     make([]byte, len(u.Data)),
+		To:   u.To,
+		From: u.From,
+		Data: make([]byte, len(u.Data)),
 	}
 
-	copy(n.ToIp, u.ToIp)
-	copy(n.FromIp, u.FromIp)
 	copy(n.Data, u.Data)
 	return n
 }
 
 type TesterConn struct {
-	Addr *Addr
+	Addr netip.AddrPort
 
 	RxPackets chan *Packet // Packets to receive into nebula
 	TxPackets chan *Packet // Packets transmitted outside by nebula
@@ -48,9 +41,9 @@ type TesterConn struct {
 	l      *logrus.Logger
 }
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
 	return &TesterConn{
-		Addr:      &Addr{ip, uint16(port)},
+		Addr:      netip.AddrPortFrom(ip, uint16(port)),
 		RxPackets: make(chan *Packet, 10),
 		TxPackets: make(chan *Packet, 10),
 		l:         l,
@@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) {
 	}
 	if u.l.Level >= logrus.DebugLevel {
 		u.l.WithField("header", h).
-			WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+			WithField("udpAddr", packet.From).
 			WithField("dataLen", len(packet.Data)).
 			Debug("UDP receiving injected packet")
 	}
@@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet {
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 
-func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
+func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	if u.closed.Load() {
 		return io.ErrClosedPipe
 	}
 
 	p := &Packet{
-		Data:     make([]byte, len(b), len(b)),
-		FromIp:   make([]byte, 16),
-		FromPort: u.Addr.Port,
-		ToIp:     make([]byte, 16),
-		ToPort:   addr.Port,
+		Data: make([]byte, len(b), len(b)),
+		From: u.Addr,
+		To:   addr,
 	}
 
 	copy(p.Data, b)
-	copy(p.ToIp, addr.IP.To16())
-	copy(p.FromIp, u.Addr.IP.To16())
-
 	u.TxPackets <- p
 	return nil
 }
@@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
 	plaintext := make([]byte, MTU)
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
-	ua := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 
 	for {
@@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
 		if !ok {
 			return
 		}
-		ua.Port = p.FromPort
-		copy(ua.IP, p.FromIp.To16())
-		r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 }
 
@@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() {
 	return func() {}
 }
 
-func (u *TesterConn) LocalAddr() (*Addr, error) {
+func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
 	return u.Addr, nil
 }
 

+ 2 - 1
udp/udp_windows.go

@@ -6,12 +6,13 @@ package udp
 import (
 	"fmt"
 	"net"
+	"net/netip"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
 )
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	if multi {
 		//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
 		// The udp stack would need to be reworked to hide away the implementation differences between