Jelajahi Sumber

Switch most everything to netip in prep for ipv6 in the overlay (#1173)

Nate Brown 1 tahun lalu
induk
melakukan
e264a0ff88
79 mengubah file dengan 1896 tambahan dan 2678 penghapusan
  1. 26 65
      allow_list.go
  2. 21 21
      allow_list_test.go
  3. 41 25
      calculated_remote.go
  4. 7 9
      calculated_remote_test.go
  5. 0 10
      cidr/parse.go
  6. 0 203
      cidr/tree4.go
  7. 0 170
      cidr/tree4_test.go
  8. 0 189
      cidr/tree6.go
  9. 0 98
      cidr/tree6_test.go
  10. 17 13
      connection_manager.go
  11. 17 17
      connection_manager_test.go
  12. 19 21
      control.go
  13. 33 24
      control_test.go
  14. 20 27
      control_tester.go
  15. 12 6
      dns_server.go
  16. 169 169
      e2e/handshakes_test.go
  17. 15 8
      e2e/helpers.go
  18. 28 24
      e2e/helpers_test.go
  19. 4 4
      e2e/router/hostmap.go
  20. 36 63
      e2e/router/router.go
  21. 58 42
      firewall.go
  22. 3 4
      firewall/packet.go
  23. 74 73
      firewall_test.go
  24. 2 0
      go.mod
  25. 6 0
      go.sum
  26. 42 16
      handshake_ix.go
  27. 50 41
      handshake_manager.go
  28. 9 9
      handshake_manager_test.go
  29. 75 71
      hostmap.go
  30. 25 34
      hostmap_test.go
  31. 4 2
      hostmap_tester.go
  32. 20 24
      inside.go
  33. 36 11
      interface.go
  34. 2 0
      iputil/packet.go
  35. 0 93
      iputil/util.go
  36. 0 17
      iputil/util_test.go
  37. 212 188
      lighthouse.go
  38. 85 102
      lighthouse_test.go
  39. 22 8
      main.go
  40. 48 47
      outside.go
  41. 5 5
      outside_test.go
  42. 3 5
      overlay/device.go
  43. 19 25
      overlay/route.go
  44. 27 16
      overlay/route_test.go
  45. 5 5
      overlay/tun.go
  46. 9 10
      overlay/tun_android.go
  47. 41 18
      overlay/tun_darwin.go
  48. 6 6
      overlay/tun_disabled.go
  49. 11 12
      overlay/tun_freebsd.go
  50. 9 10
      overlay/tun_ios.go
  51. 56 35
      overlay/tun_linux.go
  52. 14 15
      overlay/tun_netbsd.go
  53. 14 15
      overlay/tun_openbsd.go
  54. 9 10
      overlay/tun_tester.go
  55. 11 11
      overlay/tun_water_windows.go
  56. 3 3
      overlay/tun_windows.go
  57. 12 38
      overlay/tun_wintun_windows.go
  58. 7 8
      overlay/user.go
  59. 2 0
      pki.go
  60. 52 31
      relay_manager.go
  61. 73 93
      remote_list.go
  62. 98 89
      remote_list_test.go
  63. 1 1
      service/service.go
  64. 6 10
      service/service_test.go
  65. 29 36
      ssh.go
  66. 5 7
      test/tun.go
  67. 5 4
      timeout_test.go
  68. 8 6
      udp/conn.go
  69. 3 2
      udp/temp.go
  70. 0 100
      udp/udp_all.go
  71. 2 1
      udp/udp_android.go
  72. 2 1
      udp/udp_bsd.go
  73. 2 1
      udp/udp_darwin.go
  74. 23 14
      udp/udp_generic.go
  75. 43 32
      udp/udp_linux.go
  76. 2 1
      udp/udp_netbsd.go
  77. 22 21
      udp/udp_rio_windows.go
  78. 17 32
      udp/udp_tester.go
  79. 2 1
      udp/udp_windows.go

+ 26 - 65
allow_list.go

@@ -2,17 +2,16 @@ package nebula
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"net"
+	"net/netip"
 	"regexp"
 	"regexp"
 
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type AllowList struct {
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *cidr.Tree6[bool]
+	cidrTree *bart.Table[bool]
 }
 }
 
 
 type RemoteAllowList struct {
 type RemoteAllowList struct {
@@ -20,7 +19,7 @@ type RemoteAllowList struct {
 
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
 	// are *AllowList
-	insideAllowLists *cidr.Tree6[*AllowList]
+	insideAllowLists *bart.Table[*AllowList]
 }
 }
 
 
 type LocalAllowList struct {
 type LocalAllowList struct {
@@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
 	}
 
 
-	tree := cidr.NewTree6[bool]()
+	tree := new(bart.Table[bool])
 
 
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	type allowListRules struct {
 	type allowListRules struct {
@@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
 			return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
 		}
 		}
 
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		ipNet, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
 		}
 		}
 
 
+		ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits())
+
 		// TODO: should we error on duplicate CIDRs in the config?
 		// TODO: should we error on duplicate CIDRs in the config?
-		tree.AddCIDR(ipNet, value)
+		tree.Insert(ipNet, value)
 
 
-		maskBits, maskSize := ipNet.Mask.Size()
+		maskBits := ipNet.Bits()
 
 
 		var rules *allowListRules
 		var rules *allowListRules
-		if maskSize == 32 {
+		if ipNet.Addr().Is4() {
 			rules = &rules4
 			rules = &rules4
 		} else {
 		} else {
 			rules = &rules6
 			rules = &rules6
@@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 
 	if !rules4.defaultSet {
 	if !rules4.defaultSet {
 		if rules4.allValuesMatch {
 		if rules4.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
-			tree.AddCIDR(zeroCIDR, !rules4.allValues)
+			tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues)
 		} else {
 		} else {
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
 		}
 		}
@@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 
 
 	if !rules6.defaultSet {
 	if !rules6.defaultSet {
 		if rules6.allValuesMatch {
 		if rules6.allValuesMatch {
-			_, zeroCIDR, _ := net.ParseCIDR("::/0")
-			tree.AddCIDR(zeroCIDR, !rules6.allValues)
+			tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues)
 		} else {
 		} else {
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
 			return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
 		}
 		}
@@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
 	return nameRules, nil
 	return nameRules, nil
 }
 }
 
 
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
+func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) {
 	value := c.Get(k)
 	value := c.Get(k)
 	if value == nil {
 	if value == nil {
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	remoteAllowRanges := cidr.NewTree6[*AllowList]()
+	remoteAllowRanges := new(bart.Table[*AllowList])
 
 
 	rawMap, ok := value.(map[interface{}]interface{})
 	rawMap, ok := value.(map[interface{}]interface{})
 	if !ok {
 	if !ok {
@@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error
 			return nil, err
 			return nil, err
 		}
 		}
 
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		ipNet, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
+			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err)
 		}
 		}
 
 
-		remoteAllowRanges.AddCIDR(ipNet, allowList)
+		remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList)
 	}
 	}
 
 
 	return remoteAllowRanges, nil
 	return remoteAllowRanges, nil
 }
 }
 
 
-func (al *AllowList) Allow(ip net.IP) bool {
-	if al == nil {
-		return true
-	}
-
-	_, result := al.cidrTree.MostSpecificContains(ip)
-	return result
-}
-
-func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
-	if al == nil {
-		return true
-	}
-
-	_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
-	return result
-}
-
-func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
+func (al *AllowList) Allow(ip netip.Addr) bool {
 	if al == nil {
 	if al == nil {
 		return true
 		return true
 	}
 	}
 
 
-	_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+	result, _ := al.cidrTree.Lookup(ip)
 	return result
 	return result
 }
 }
 
 
-func (al *LocalAllowList) Allow(ip net.IP) bool {
+func (al *LocalAllowList) Allow(ip netip.Addr) bool {
 	if al == nil {
 	if al == nil {
 		return true
 		return true
 	}
 	}
@@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 	return !al.nameRules[0].Allow
 }
 }
 
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool {
+func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
 	if al == nil {
 	if al == nil {
 		return true
 		return true
 	}
 	}
 	return al.AllowList.Allow(ip)
 	return al.AllowList.Allow(ip)
 }
 }
 
 
-func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool {
+func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
 	if !al.getInsideAllowList(vpnIp).Allow(ip) {
 	if !al.getInsideAllowList(vpnIp).Allow(ip) {
 		return false
 		return false
 	}
 	}
 	return al.AllowList.Allow(ip)
 	return al.AllowList.Allow(ip)
 }
 }
 
 
-func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool {
-	if al == nil {
-		return true
-	}
-	if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) {
-		return false
-	}
-	return al.AllowList.AllowIpV4(ip)
-}
-
-func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
-	if al == nil {
-		return true
-	}
-	if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) {
-		return false
-	}
-	return al.AllowList.AllowIpV6(hi, lo)
-}
-
-func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
+func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
 	if al.insideAllowLists != nil {
-		ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnIp)
 		if ok {
 		if ok {
 			return inside
 			return inside
 		}
 		}

+ 21 - 21
allow_list_test.go

@@ -1,11 +1,11 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"net"
+	"net/netip"
 	"regexp"
 	"regexp"
 	"testing"
 	"testing"
 
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
@@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 		"192.168.0.0": true,
 		"192.168.0.0": true,
 	}
 	}
 	r, err := newAllowListFromConfig(c, "allowlist", nil)
 	r, err := newAllowListFromConfig(c, "allowlist", nil)
-	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0")
+	assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
 	assert.Nil(t, r)
 	assert.Nil(t, r)
 
 
 	c.Settings["allowlist"] = map[interface{}]interface{}{
 	c.Settings["allowlist"] = map[interface{}]interface{}{
@@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) {
 }
 }
 
 
 func TestAllowList_Allow(t *testing.T) {
 func TestAllowList_Allow(t *testing.T) {
-	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
-
-	tree := cidr.NewTree6[bool]()
-	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
-	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
-	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
-	tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true)
-	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true)
-	tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false)
-	tree.AddCIDR(cidr.Parse("::1/128"), true)
-	tree.AddCIDR(cidr.Parse("::2/128"), false)
+	assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1")))
+
+	tree := new(bart.Table[bool])
+	tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true)
+	tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false)
+	tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true)
+	tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false)
+	tree.Insert(netip.MustParsePrefix("::1/128"), true)
+	tree.Insert(netip.MustParsePrefix("::2/128"), false)
 	al := &AllowList{cidrTree: tree}
 	al := &AllowList{cidrTree: tree}
 
 
-	assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1")))
-	assert.Equal(t, true, al.Allow(net.ParseIP("::1")))
-	assert.Equal(t, false, al.Allow(net.ParseIP("::2")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1")))
+	assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42")))
+	assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1")))
+	assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1")))
+	assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2")))
 }
 }
 
 
 func TestLocalAllowList_AllowName(t *testing.T) {
 func TestLocalAllowList_AllowName(t *testing.T) {

+ 41 - 25
calculated_remote.go

@@ -1,41 +1,36 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"encoding/binary"
 	"fmt"
 	"fmt"
 	"math"
 	"math"
 	"net"
 	"net"
+	"net/netip"
 	"strconv"
 	"strconv"
 
 
-	"github.com/slackhq/nebula/cidr"
+	"github.com/gaissmai/bart"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 // This allows us to "guess" what the remote might be for a host while we wait
 // This allows us to "guess" what the remote might be for a host while we wait
 // for the lighthouse response. See "lighthouse.calculated_remotes" in the
 // for the lighthouse response. See "lighthouse.calculated_remotes" in the
 // example config file.
 // example config file.
 type calculatedRemote struct {
 type calculatedRemote struct {
-	ipNet  net.IPNet
-	maskIP iputil.VpnIp
-	mask   iputil.VpnIp
-	port   uint32
+	ipNet netip.Prefix
+	mask  netip.Prefix
+	port  uint32
 }
 }
 
 
-func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) {
-	// Ensure this is an IPv4 mask that we expect
-	ones, bits := ipNet.Mask.Size()
-	if ones == 0 || bits != 32 {
-		return nil, fmt.Errorf("invalid mask: %v", ipNet)
-	}
+func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) {
+	masked := maskCidr.Masked()
 	if port < 0 || port > math.MaxUint16 {
 	if port < 0 || port > math.MaxUint16 {
 		return nil, fmt.Errorf("invalid port: %d", port)
 		return nil, fmt.Errorf("invalid port: %d", port)
 	}
 	}
 
 
 	return &calculatedRemote{
 	return &calculatedRemote{
-		ipNet:  *ipNet,
-		maskIP: iputil.Ip2VpnIp(ipNet.IP),
-		mask:   iputil.Ip2VpnIp(ipNet.Mask),
-		port:   uint32(port),
+		ipNet: maskCidr,
+		mask:  masked,
+		port:  uint32(port),
 	}, nil
 	}, nil
 }
 }
 
 
@@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string {
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 	return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port)
 }
 }
 
 
-func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
+func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort {
 	// Combine the masked bytes of the "mask" IP with the unmasked bytes
 	// Combine the masked bytes of the "mask" IP with the unmasked bytes
 	// of the overlay IP
 	// of the overlay IP
-	masked := (c.maskIP & c.mask) | (ip & ^c.mask)
+	if c.ipNet.Addr().Is4() {
+		return c.apply4(ip)
+	}
+	return c.apply6(ip)
+}
+
+func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort {
+	//TODO: IPV6-WORK this can be less crappy
+	maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen())
+	mask := binary.BigEndian.Uint32(maskb[:])
+
+	b := c.mask.Addr().As4()
+	maskIp := binary.BigEndian.Uint32(b[:])
+
+	b = ip.As4()
+	intIp := binary.BigEndian.Uint32(b[:])
+
+	return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port}
+}
 
 
-	return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
+func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort {
+	//TODO: IPV6-WORK
+	panic("Can not calculate ipv6 remote addresses")
 }
 }
 
 
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) {
 	value := c.Get(k)
 	value := c.Get(k)
 	if value == nil {
 	if value == nil {
 		return nil, nil
 		return nil, nil
 	}
 	}
 
 
-	calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
+	calculatedRemotes := new(bart.Table[[]*calculatedRemote])
 
 
 	rawMap, ok := value.(map[any]any)
 	rawMap, ok := value.(map[any]any)
 	if !ok {
 	if !ok {
@@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu
 			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
 			return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
 		}
 		}
 
 
-		_, ipNet, err := net.ParseCIDR(rawCIDR)
+		cidr, err := netip.ParsePrefix(rawCIDR)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 			return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
 		}
 		}
 
 
+		//TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here
 		entry, err := newCalculatedRemotesListFromConfig(rawValue)
 		entry, err := newCalculatedRemotesListFromConfig(rawValue)
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 			return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err)
 		}
 		}
 
 
-		calculatedRemotes.AddCIDR(ipNet, entry)
+		calculatedRemotes.Insert(cidr, entry)
 	}
 	}
 
 
 	return calculatedRemotes, nil
 	return calculatedRemotes, nil
@@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 	if !ok {
 	if !ok {
 		return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
 		return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue)
 	}
 	}
-	_, ipNet, err := net.ParseCIDR(rawMask)
+	maskCidr, err := netip.ParsePrefix(rawMask)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("invalid mask: %s", rawMask)
 		return nil, fmt.Errorf("invalid mask: %s", rawMask)
 	}
 	}
@@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) {
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 		return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue)
 	}
 	}
 
 
-	return newCalculatedRemote(ipNet, port)
+	return newCalculatedRemote(maskCidr, port)
 }
 }

+ 7 - 9
calculated_remote_test.go

@@ -1,27 +1,25 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 
 
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/require"
 )
 )
 
 
 func TestCalculatedRemoteApply(t *testing.T) {
 func TestCalculatedRemoteApply(t *testing.T) {
-	_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
+	ipNet, err := netip.ParsePrefix("192.168.1.0/24")
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
 	c, err := newCalculatedRemote(ipNet, 4242)
 	c, err := newCalculatedRemote(ipNet, 4242)
 	require.NoError(t, err)
 	require.NoError(t, err)
 
 
-	input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182})
+	input, err := netip.ParseAddr("10.0.10.182")
+	assert.NoError(t, err)
 
 
-	expected := &Ip4AndPort{
-		Ip:   uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})),
-		Port: 4242,
-	}
+	expected, err := netip.ParseAddr("192.168.1.182")
+	assert.NoError(t, err)
 
 
-	assert.Equal(t, expected, c.Apply(input))
+	assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input))
 }
 }

+ 0 - 10
cidr/parse.go

@@ -1,10 +0,0 @@
-package cidr
-
-import "net"
-
-// Parse is a convenience function that returns only the IPNet
-// This function ignores errors since it is primarily a test helper, the result could be nil
-func Parse(s string) *net.IPNet {
-	_, c, _ := net.ParseCIDR(s)
-	return c
-}

+ 0 - 203
cidr/tree4.go

@@ -1,203 +0,0 @@
-package cidr
-
-import (
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-type Node[T any] struct {
-	left     *Node[T]
-	right    *Node[T]
-	parent   *Node[T]
-	hasValue bool
-	value    T
-}
-
-type entry[T any] struct {
-	CIDR  *net.IPNet
-	Value T
-}
-
-type Tree4[T any] struct {
-	root *Node[T]
-	list []entry[T]
-}
-
-const (
-	startbit = iputil.VpnIp(0x80000000)
-)
-
-func NewTree4[T any]() *Tree4[T] {
-	tree := new(Tree4[T])
-	tree.root = &Node[T]{}
-	tree.list = []entry[T]{}
-	return tree
-}
-
-func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
-	bit := startbit
-	node := tree.root
-	next := tree.root
-
-	ip := iputil.Ip2VpnIp(cidr.IP)
-	mask := iputil.Ip2VpnIp(cidr.Mask)
-
-	// Find our last ancestor in the tree
-	for bit&mask != 0 {
-		if ip&bit != 0 {
-			next = node.right
-		} else {
-			next = node.left
-		}
-
-		if next == nil {
-			break
-		}
-
-		bit = bit >> 1
-		node = next
-	}
-
-	// We already have this range so update the value
-	if next != nil {
-		addCIDR := cidr.String()
-		for i, v := range tree.list {
-			if addCIDR == v.CIDR.String() {
-				tree.list = append(tree.list[:i], tree.list[i+1:]...)
-				break
-			}
-		}
-
-		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-		node.value = val
-		node.hasValue = true
-		return
-	}
-
-	// Build up the rest of the tree we don't already have
-	for bit&mask != 0 {
-		next = &Node[T]{}
-		next.parent = node
-
-		if ip&bit != 0 {
-			node.right = next
-		} else {
-			node.left = next
-		}
-
-		bit >>= 1
-		node = next
-	}
-
-	// Final node marks our cidr, set the value
-	node.value = val
-	node.hasValue = true
-	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
-}
-
-// Contains finds the first match, which may be the least specific
-func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			return true, node.value
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-
-	}
-
-	return false, value
-}
-
-// MostSpecificContains finds the most specific match
-func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			value = node.value
-			ok = true
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return ok, value
-}
-
-type eachFunc[T any] func(T) bool
-
-// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
-// The final return value will be true if the provided function returned true
-func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
-	bit := startbit
-	node := tree.root
-
-	for node != nil {
-		if node.hasValue {
-			// If the each func returns true then we can exit the loop
-			if each(node.value) {
-				return true
-			}
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return false
-}
-
-// GetCIDR returns the entry added by the most recent matching AddCIDR call
-func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
-	bit := startbit
-	node := tree.root
-
-	ip := iputil.Ip2VpnIp(cidr.IP)
-	mask := iputil.Ip2VpnIp(cidr.Mask)
-
-	// Find our last ancestor in the tree
-	for node != nil && bit&mask != 0 {
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit = bit >> 1
-	}
-
-	if bit&mask == 0 && node != nil {
-		value = node.value
-		ok = node.hasValue
-	}
-
-	return ok, value
-}
-
-// List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4[T]) List() []entry[T] {
-	return tree.list
-}

+ 0 - 170
cidr/tree4_test.go

@@ -1,170 +0,0 @@
-package cidr
-
-import (
-	"net"
-	"testing"
-
-	"github.com/slackhq/nebula/iputil"
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDRTree_List(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
-	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
-	tree.AddCIDR(Parse("1.0.0.0/16"), "4")
-	list := tree.List()
-	assert.Len(t, list, 2)
-	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
-	assert.Equal(t, "2", list[0].Value)
-	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
-	assert.Equal(t, "4", list[1].Value)
-}
-
-func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
-	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4a", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-}
-
-func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.0/30"), "4b")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4b", "4.1.1.2"},
-		{true, "4c", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-}
-
-func TestTree4_GetCIDR(t *testing.T) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
-	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IPNet  *net.IPNet
-	}{
-		{true, "1", Parse("1.0.0.0/8")},
-		{true, "2", Parse("2.1.0.0/16")},
-		{true, "3", Parse("3.1.1.0/24")},
-		{true, "4a", Parse("4.1.1.0/24")},
-		{true, "4b", Parse("4.1.1.1/32")},
-		{true, "4c", Parse("4.1.2.1/32")},
-		{true, "5", Parse("254.0.0.0/4")},
-		{false, "", Parse("2.0.0.0/8")},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.GetCIDR(tt.IPNet)
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-}
-
-func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
-
-	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Contains(ip)
-		}
-	})
-}

+ 0 - 189
cidr/tree6.go

@@ -1,189 +0,0 @@
-package cidr
-
-import (
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-const startbit6 = uint64(1 << 63)
-
-type Tree6[T any] struct {
-	root4 *Node[T]
-	root6 *Node[T]
-}
-
-func NewTree6[T any]() *Tree6[T] {
-	tree := new(Tree6[T])
-	tree.root4 = &Node[T]{}
-	tree.root6 = &Node[T]{}
-	return tree
-}
-
-func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
-	var node, next *Node[T]
-
-	cidrIP, ipv4 := isIPV4(cidr.IP)
-	if ipv4 {
-		node = tree.root4
-		next = tree.root4
-
-	} else {
-		node = tree.root6
-		next = tree.root6
-	}
-
-	for i := 0; i < len(cidrIP); i += 4 {
-		ip := iputil.Ip2VpnIp(cidrIP[i : i+4])
-		mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4])
-		bit := startbit
-
-		// Find our last ancestor in the tree
-		for bit&mask != 0 {
-			if ip&bit != 0 {
-				next = node.right
-			} else {
-				next = node.left
-			}
-
-			if next == nil {
-				break
-			}
-
-			bit = bit >> 1
-			node = next
-		}
-
-		// Build up the rest of the tree we don't already have
-		for bit&mask != 0 {
-			next = &Node[T]{}
-			next.parent = node
-
-			if ip&bit != 0 {
-				node.right = next
-			} else {
-				node.left = next
-			}
-
-			bit >>= 1
-			node = next
-		}
-	}
-
-	// Final node marks our cidr, set the value
-	node.value = val
-	node.hasValue = true
-}
-
-// Finds the most specific match
-func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
-	var node *Node[T]
-
-	wholeIP, ipv4 := isIPV4(ip)
-	if ipv4 {
-		node = tree.root4
-	} else {
-		node = tree.root6
-	}
-
-	for i := 0; i < len(wholeIP); i += 4 {
-		ip := iputil.Ip2VpnIp(wholeIP[i : i+4])
-		bit := startbit
-
-		for node != nil {
-			if node.hasValue {
-				value = node.value
-				ok = true
-			}
-
-			if bit == 0 {
-				break
-			}
-
-			if ip&bit != 0 {
-				node = node.right
-			} else {
-				node = node.left
-			}
-
-			bit >>= 1
-		}
-	}
-
-	return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
-	bit := startbit
-	node := tree.root4
-
-	for node != nil {
-		if node.hasValue {
-			value = node.value
-			ok = true
-		}
-
-		if ip&bit != 0 {
-			node = node.right
-		} else {
-			node = node.left
-		}
-
-		bit >>= 1
-	}
-
-	return ok, value
-}
-
-func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
-	ip := hi
-	node := tree.root6
-
-	for i := 0; i < 2; i++ {
-		bit := startbit6
-
-		for node != nil {
-			if node.hasValue {
-				value = node.value
-				ok = true
-			}
-
-			if bit == 0 {
-				break
-			}
-
-			if ip&bit != 0 {
-				node = node.right
-			} else {
-				node = node.left
-			}
-
-			bit >>= 1
-		}
-
-		ip = lo
-	}
-
-	return ok, value
-}
-
-func isIPV4(ip net.IP) (net.IP, bool) {
-	if len(ip) == net.IPv4len {
-		return ip, true
-	}
-
-	if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff {
-		return ip[12:16], true
-	}
-
-	return ip, false
-}
-
-func isZeros(p net.IP) bool {
-	for i := 0; i < len(p); i++ {
-		if p[i] != 0 {
-			return false
-		}
-	}
-	return true
-}

+ 0 - 98
cidr/tree6_test.go

@@ -1,98 +0,0 @@
-package cidr
-
-import (
-	"encoding/binary"
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewTree6[string]()
-	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
-	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
-	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
-	tree.AddCIDR(Parse("4.1.1.1/24"), "4a")
-	tree.AddCIDR(Parse("4.1.1.1/30"), "4b")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "4c")
-	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "1", "1.0.0.0"},
-		{true, "1", "1.255.255.255"},
-		{true, "2", "2.1.0.0"},
-		{true, "2", "2.1.255.255"},
-		{true, "3", "3.1.1.0"},
-		{true, "3", "3.1.1.255"},
-		{true, "4a", "4.1.1.255"},
-		{true, "4b", "4.1.1.2"},
-		{true, "4c", "4.1.1.1"},
-		{true, "5", "240.0.0.0"},
-		{true, "5", "255.255.255.255"},
-		{true, "6a", "1:2:0:4:1:1:1:1"},
-		{true, "6b", "1:2:0:4:5:1:1:1"},
-		{true, "6c", "1:2:0:4:5:0:0:0"},
-		{false, "", "239.0.0.0"},
-		{false, "", "4.1.2.2"},
-	}
-
-	for _, tt := range tests {
-		ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-
-	tree = NewTree6[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	tree.AddCIDR(Parse("::/0"), "cool6")
-	ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("::"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool6", r)
-
-	ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
-	assert.True(t, ok)
-	assert.Equal(t, "cool6", r)
-}
-
-func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewTree6[string]()
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
-	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
-
-	tests := []struct {
-		Found  bool
-		Result interface{}
-		IP     string
-	}{
-		{true, "6a", "1:2:0:4:1:1:1:1"},
-		{true, "6b", "1:2:0:4:5:1:1:1"},
-		{true, "6c", "1:2:0:4:5:0:0:0"},
-	}
-
-	for _, tt := range tests {
-		ip := net.ParseIP(tt.IP)
-		hi := binary.BigEndian.Uint64(ip[:8])
-		lo := binary.BigEndian.Uint64(ip[8:])
-
-		ok, r := tree.MostSpecificContainsIpV6(hi, lo)
-		assert.Equal(t, tt.Found, ok)
-		assert.Equal(t, tt.Result, r)
-	}
-}

+ 17 - 13
connection_manager.go

@@ -3,6 +3,8 @@ package nebula
 import (
 import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
+	"encoding/binary"
+	"net/netip"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
@@ -10,8 +12,6 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 type trafficDecision int
 type trafficDecision int
@@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
 		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
 
 
 		var index uint32
 		var index uint32
-		var relayFrom iputil.VpnIp
-		var relayTo iputil.VpnIp
+		var relayFrom netip.Addr
+		var relayTo netip.Addr
 		switch {
 		switch {
 		case ok && existing.State == Established:
 		case ok && existing.State == Established:
 			// This relay already exists in newhostinfo, then do nothing.
 			// This relay already exists in newhostinfo, then do nothing.
@@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			index = existing.LocalIndex
 			index = existing.LocalIndex
 			switch r.Type {
 			switch r.Type {
 			case TerminalType:
 			case TerminalType:
-				relayFrom = n.intf.myVpnIp
+				relayFrom = n.intf.myVpnNet.Addr()
 				relayTo = existing.PeerIp
 				relayTo = existing.PeerIp
 			case ForwardingType:
 			case ForwardingType:
 				relayFrom = existing.PeerIp
 				relayFrom = existing.PeerIp
@@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			}
 			}
 			switch r.Type {
 			switch r.Type {
 			case TerminalType:
 			case TerminalType:
-				relayFrom = n.intf.myVpnIp
+				relayFrom = n.intf.myVpnNet.Addr()
 				relayTo = r.PeerIp
 				relayTo = r.PeerIp
 			case ForwardingType:
 			case ForwardingType:
 				relayFrom = r.PeerIp
 				relayFrom = r.PeerIp
@@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			}
 			}
 		}
 		}
 
 
+		//TODO: IPV6-WORK
+		relayFromB := relayFrom.As4()
+		relayToB := relayTo.As4()
+
 		// Send a CreateRelayRequest to the peer.
 		// Send a CreateRelayRequest to the peer.
 		req := NebulaControl{
 		req := NebulaControl{
 			Type:                NebulaControl_CreateRelayRequest,
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
 			InitiatorRelayIndex: index,
-			RelayFromIp:         uint32(relayFrom),
-			RelayToIp:           uint32(relayTo),
+			RelayFromIp:         binary.BigEndian.Uint32(relayFromB[:]),
+			RelayToIp:           binary.BigEndian.Uint32(relayToB[:]),
 		}
 		}
 		msg, err := req.Marshal()
 		msg, err := req.Marshal()
 		if err != nil {
 		if err != nil {
@@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		} else {
 		} else {
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
 			n.l.WithFields(logrus.Fields{
 			n.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
-				"relayTo":             iputil.VpnIp(req.RelayToIp),
+				"relayFrom":           req.RelayFromIp,
+				"relayTo":             req.RelayToIp,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,
 				"vpnIp":               newhostinfo.vpnIp}).
 				"vpnIp":               newhostinfo.vpnIp}).
@@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
 	// Let's sort this out.
 
 
-	if current.vpnIp < n.intf.myVpnIp {
+	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
 		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
 		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
 		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
 		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
 		// The remotes vpn ip is lower than mine. I will not flip.
 		// The remotes vpn ip is lower than mine. I will not flip.
@@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 	}
 	}
 
 
 	if n.punchy.GetTargetEverything() {
 	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
+		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
 			n.metricsTxPunchy.Inc(1)
 			n.metricsTxPunchy.Inc(1)
 			n.intf.outside.WriteTo([]byte{1}, addr)
 			n.intf.outside.WriteTo([]byte{1}, addr)
 		})
 		})
 
 
-	} else if hostinfo.remote != nil {
+	} else if hostinfo.remote.IsValid() {
 		n.metricsTxPunchy.Inc(1)
 		n.metricsTxPunchy.Inc(1)
 		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 	}

+ 17 - 17
connection_manager_test.go

@@ -5,28 +5,26 @@ import (
 	"crypto/ed25519"
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/rand"
 	"net"
 	"net"
+	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-var vpnIp iputil.VpnIp
-
 func newTestLighthouse() *LightHouse {
 func newTestLighthouse() *LightHouse {
 	lh := &LightHouse{
 	lh := &LightHouse{
 		l:         test.NewLogger(),
 		l:         test.NewLogger(),
-		addrMap:   map[iputil.VpnIp]*RemoteList{},
-		queryChan: make(chan iputil.VpnIp, 10),
+		addrMap:   map[netip.Addr]*RemoteList{},
+		queryChan: make(chan netip.Addr, 10),
 	}
 	}
-	lighthouses := map[iputil.VpnIp]struct{}{}
-	staticList := map[iputil.VpnIp]struct{}{}
+	lighthouses := map[netip.Addr]struct{}{}
+	staticList := map[netip.Addr]struct{}{}
 
 
 	lh.lighthouses.Store(&lighthouses)
 	lh.lighthouses.Store(&lighthouses)
 	lh.staticList.Store(&staticList)
 	lh.staticList.Store(&staticList)
@@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse {
 func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
 	hostMap := newHostMap(l, vpncidr)
 	hostMap := newHostMap(l, vpncidr)
@@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 func Test_NewConnectionManagerTest2(t *testing.T) {
 func Test_NewConnectionManagerTest2(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
 	//_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 
 
 	// Very incomplete mock objects
 	// Very incomplete mock objects
 	hostMap := newHostMap(l, vpncidr)
 	hostMap := newHostMap(l, vpncidr)
@@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		IP:   net.IPv4(172, 1, 1, 2),
 		IP:   net.IPv4(172, 1, 1, 2),
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 	}
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
 	hostMap := newHostMap(l, vpncidr)
 	hostMap := newHostMap(l, vpncidr)
 	hostMap.preferredRanges.Store(&preferredRanges)
 	hostMap.preferredRanges.Store(&preferredRanges)
 
 

+ 19 - 21
control.go

@@ -2,7 +2,7 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
 	"syscall"
 	"syscall"
@@ -10,9 +10,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/overlay"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
@@ -21,10 +19,10 @@ import (
 type controlEach func(h *HostInfo)
 type controlEach func(h *HostInfo)
 
 
 type controlHostLister interface {
 type controlHostLister interface {
-	QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+	QueryVpnIp(vpnIp netip.Addr) *HostInfo
 	ForEachIndex(each controlEach)
 	ForEachIndex(each controlEach)
 	ForEachVpnIp(each controlEach)
 	ForEachVpnIp(each controlEach)
-	GetPreferredRanges() []*net.IPNet
+	GetPreferredRanges() []netip.Prefix
 }
 }
 
 
 type Control struct {
 type Control struct {
@@ -39,15 +37,15 @@ type Control struct {
 }
 }
 
 
 type ControlHostInfo struct {
 type ControlHostInfo struct {
-	VpnIp                  net.IP                  `json:"vpnIp"`
+	VpnIp                  netip.Addr              `json:"vpnIp"`
 	LocalIndex             uint32                  `json:"localIndex"`
 	LocalIndex             uint32                  `json:"localIndex"`
 	RemoteIndex            uint32                  `json:"remoteIndex"`
 	RemoteIndex            uint32                  `json:"remoteIndex"`
-	RemoteAddrs            []*udp.Addr             `json:"remoteAddrs"`
+	RemoteAddrs            []netip.AddrPort        `json:"remoteAddrs"`
 	Cert                   *cert.NebulaCertificate `json:"cert"`
 	Cert                   *cert.NebulaCertificate `json:"cert"`
 	MessageCounter         uint64                  `json:"messageCounter"`
 	MessageCounter         uint64                  `json:"messageCounter"`
-	CurrentRemote          *udp.Addr               `json:"currentRemote"`
-	CurrentRelaysToMe      []iputil.VpnIp          `json:"currentRelaysToMe"`
-	CurrentRelaysThroughMe []iputil.VpnIp          `json:"currentRelaysThroughMe"`
+	CurrentRemote          netip.AddrPort          `json:"currentRemote"`
+	CurrentRelaysToMe      []netip.Addr            `json:"currentRelaysToMe"`
+	CurrentRelaysThroughMe []netip.Addr            `json:"currentRelaysThroughMe"`
 }
 }
 
 
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
 // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock()
@@ -132,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 }
 }
 
 
 // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
 // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
-func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo {
 	var hl controlHostLister
 	var hl controlHostLister
 	if pending {
 	if pending {
 		hl = c.f.handshakeManager
 		hl = c.f.handshakeManager
@@ -150,19 +149,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 }
 }
 
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 // SetRemoteForTunnel forces a tunnel to use a specific remote
-func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo {
 	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
 	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return nil
 		return nil
 	}
 	}
 
 
-	hostInfo.SetRemote(addr.Copy())
+	hostInfo.SetRemote(addr)
 	ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
 	ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
 	return &ch
 	return &ch
 }
 }
 
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
-func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
+// Caller should take care to Unmap() any 4in6 addresses prior to calling.
+func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool {
 	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
 	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo == nil {
 	if hostInfo == nil {
 		return false
 		return false
@@ -205,7 +206,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	}
 	}
 
 
 	// Learn which hosts are being used as relays, so we can shut them down last.
 	// Learn which hosts are being used as relays, so we can shut them down last.
-	relayingHosts := map[iputil.VpnIp]*HostInfo{}
+	relayingHosts := map[netip.Addr]*HostInfo{}
 	// Grab the hostMap lock to access the Relays map
 	// Grab the hostMap lock to access the Relays map
 	c.f.hostMap.Lock()
 	c.f.hostMap.Lock()
 	for _, relayingHost := range c.f.hostMap.Relays {
 	for _, relayingHost := range c.f.hostMap.Relays {
@@ -236,15 +237,16 @@ func (c *Control) Device() overlay.Device {
 	return c.f.inside
 	return c.f.inside
 }
 }
 
 
-func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
+func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo {
 
 
 	chi := ControlHostInfo{
 	chi := ControlHostInfo{
-		VpnIp:                  h.vpnIp.ToIP(),
+		VpnIp:                  h.vpnIp,
 		LocalIndex:             h.localIndexId,
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
+		CurrentRemote:          h.remote,
 	}
 	}
 
 
 	if h.ConnectionState != nil {
 	if h.ConnectionState != nil {
@@ -255,10 +257,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 		chi.Cert = c.Copy()
 		chi.Cert = c.Copy()
 	}
 	}
 
 
-	if h.remote != nil {
-		chi.CurrentRemote = h.remote.Copy()
-	}
-
 	return chi
 	return chi
 }
 }
 
 

+ 33 - 24
control_test.go

@@ -2,15 +2,14 @@ package nebula
 
 
 import (
 import (
 	"net"
 	"net"
+	"net/netip"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
-	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
 	// To properly ensure we are not exposing core memory to the caller
-	hm := newHostMap(l, &net.IPNet{})
-	hm.preferredRanges.Store(&[]*net.IPNet{})
+	hm := newHostMap(l, netip.Prefix{})
+	hm.preferredRanges.Store(&[]netip.Prefix{})
+
+	remote1 := netip.MustParseAddrPort("0.0.0.100:4444")
+	remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444")
 
 
-	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
-	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
 	ipNet := net.IPNet{
-		IP:   net.IPv4(1, 2, 3, 4),
+		IP:   remote1.Addr().AsSlice(),
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 	}
 
 
 	ipNet2 := net.IPNet{
 	ipNet2 := net.IPNet{
-		IP:   net.ParseIP("1:2:3:4:5:6:7:8"),
+		IP:   remote2.Addr().AsSlice(),
 		Mask: net.IPMask{255, 255, 255, 0},
 		Mask: net.IPMask{255, 255, 255, 0},
 	}
 	}
 
 
@@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 	}
 
 
 	remotes := NewRemoteList(nil)
 	remotes := NewRemoteList(nil)
-	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
-	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
+	remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port()))
+	remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port()))
+
+	vpnIp, ok := netip.AddrFromSlice(ipNet.IP)
+	assert.True(t, ok)
+
 	hm.unlockedAddHostInfo(&HostInfo{
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remote:  remote1,
 		remotes: remotes,
 		remotes: remotes,
@@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		vpnIp:         iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp:         vpnIp,
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 		},
 	}, &Interface{})
 	}, &Interface{})
 
 
+	vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP)
+	assert.True(t, ok)
+
 	hm.unlockedAddHostInfo(&HostInfo{
 	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remote:  remote1,
 		remotes: remotes,
 		remotes: remotes,
@@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		},
 		},
 		remoteIndexId: 200,
 		remoteIndexId: 200,
 		localIndexId:  201,
 		localIndexId:  201,
-		vpnIp:         iputil.Ip2VpnIp(ipNet2.IP),
+		vpnIp:         vpnIp2,
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 		},
 	}, &Interface{})
 	}, &Interface{})
@@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		l: logrus.New(),
 		l: logrus.New(),
 	}
 	}
 
 
-	thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false)
+	thi := c.GetHostInfoByVpnIp(vpnIp, false)
 
 
 	expectedInfo := ControlHostInfo{
 	expectedInfo := ControlHostInfo{
-		VpnIp:                  net.IPv4(1, 2, 3, 4).To4(),
+		VpnIp:                  vpnIp,
 		LocalIndex:             201,
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteIndex:            200,
-		RemoteAddrs:            []*udp.Addr{remote2, remote1},
+		RemoteAddrs:            []netip.AddrPort{remote2, remote1},
 		Cert:                   crt.Copy(),
 		Cert:                   crt.Copy(),
 		MessageCounter:         0,
 		MessageCounter:         0,
-		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
-		CurrentRelaysToMe:      []iputil.VpnIp{},
-		CurrentRelaysThroughMe: []iputil.VpnIp{},
+		CurrentRemote:          remote1,
+		CurrentRelaysToMe:      []netip.Addr{},
+		CurrentRelaysThroughMe: []netip.Addr{},
 	}
 	}
 
 
 	// Make sure we don't have any unexpected fields
 	// Make sure we don't have any unexpected fields
 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
-	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
+	assert.EqualValues(t, &expectedInfo, thi)
+	//TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here
+	//test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	// Make sure we don't panic if the host info doesn't have a cert yet
 	assert.NotPanics(t, func() {
 	assert.NotPanics(t, func() {
-		thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false)
+		thi = c.GetHostInfoByVpnIp(vpnIp2, false)
 	})
 	})
 }
 }
 
 

+ 20 - 27
control_tester.go

@@ -4,14 +4,13 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"net"
+	"net/netip"
 
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/layers"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
@@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType,
 
 
 // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
 // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 // This is necessary if you did not configure static hosts or are not running a lighthouse
-func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
+func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
 
 
-	iVpnIp := iputil.Ip2VpnIp(vpnIp)
-	if v4 := toAddr.IP.To4(); v4 != nil {
-		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+	if toAddr.Addr().Is4() {
+		remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
 	} else {
 	} else {
-		remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+		remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port()))
 	}
 	}
 }
 }
 
 
 // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
 // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp
 // This is necessary to inform an initiator of possible relays for communicating with a responder
 // This is necessary to inform an initiator of possible relays for communicating with a responder
-func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) {
+func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) {
 	c.f.lightHouse.Lock()
 	c.f.lightHouse.Lock()
-	remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp))
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp)
 	remoteList.Lock()
 	remoteList.Lock()
 	defer remoteList.Unlock()
 	defer remoteList.Unlock()
 	c.f.lightHouse.Unlock()
 	c.f.lightHouse.Unlock()
 
 
-	iVpnIp := iputil.Ip2VpnIp(vpnIp)
-	uVpnIp := []uint32{}
-	for _, rVPnIp := range relayVpnIps {
-		uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp)))
-	}
-
-	remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp)
+	remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps)
 }
 }
 
 
 // GetFromTun will pull a packet off the tun side of nebula
 // GetFromTun will pull a packet off the tun side of nebula
@@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) {
 }
 }
 
 
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
 // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol
-func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) {
+func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) {
+	//TODO: IPV6-WORK
 	ip := layers.IPv4{
 	ip := layers.IPv4{
 		Version:  4,
 		Version:  4,
 		TTL:      64,
 		TTL:      64,
 		Protocol: layers.IPProtocolUDP,
 		Protocol: layers.IPProtocolUDP,
-		SrcIP:    c.f.inside.Cidr().IP,
-		DstIP:    toIp,
+		SrcIP:    c.f.inside.Cidr().Addr().Unmap().AsSlice(),
+		DstIP:    toIp.Unmap().AsSlice(),
 	}
 	}
 
 
 	udp := layers.UDP{
 	udp := layers.UDP{
@@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 	c.f.inside.(*overlay.TestTun).Send(buffer.Bytes())
 }
 }
 
 
-func (c *Control) GetVpnIp() iputil.VpnIp {
-	return c.f.myVpnIp
+func (c *Control) GetVpnIp() netip.Addr {
+	return c.f.myVpnNet.Addr()
 }
 }
 
 
-func (c *Control) GetUDPAddr() string {
-	return c.f.outside.(*udp.TesterConn).Addr.String()
+func (c *Control) GetUDPAddr() netip.AddrPort {
+	return c.f.outside.(*udp.TesterConn).Addr
 }
 }
 
 
-func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool {
+	hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		return false
 		return false
 	}
 	}
@@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
 	return c.f.pki.GetCertState().Certificate
 	return c.f.pki.GetCertState().Certificate
 }
 }
 
 
-func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+func (c *Control) ReHandshake(vpnIp netip.Addr) {
 	c.f.handshakeManager.StartHandshake(vpnIp, nil)
 	c.f.handshakeManager.StartHandshake(vpnIp, nil)
 }
 }

+ 12 - 6
dns_server.go

@@ -3,6 +3,7 @@ package nebula
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -10,7 +11,6 @@ import (
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 // This whole thing should be rewritten to use context
 // This whole thing should be rewritten to use context
@@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string {
 }
 }
 
 
 func (d *dnsRecords) QueryCert(data string) string {
 func (d *dnsRecords) QueryCert(data string) string {
-	ip := net.ParseIP(data[:len(data)-1])
-	if ip == nil {
+	ip, err := netip.ParseAddr(data[:len(data)-1])
+	if err != nil {
 		return ""
 		return ""
 	}
 	}
-	iip := iputil.Ip2VpnIp(ip)
-	hostinfo := d.hostMap.QueryVpnIp(iip)
+
+	hostinfo := d.hostMap.QueryVpnIp(ip)
 	if hostinfo == nil {
 	if hostinfo == nil {
 		return ""
 		return ""
 	}
 	}
+
 	q := hostinfo.GetCert()
 	q := hostinfo.GetCert()
 	if q == nil {
 	if q == nil {
 		return ""
 		return ""
 	}
 	}
+
 	cert := q.Details
 	cert := q.Details
 	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
 	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
 	return c
 	return c
@@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
 			}
 			}
 		case dns.TypeTXT:
 		case dns.TypeTXT:
 			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
 			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
-			b := net.ParseIP(a)
+			b, err := netip.ParseAddr(a)
+			if err != nil {
+				return
+			}
+
 			// We don't answer these queries from non nebula nodes or localhost
 			// We don't answer these queries from non nebula nodes or localhost
 			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
 			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
 			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
 			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {

+ 169 - 169
e2e/handshakes_test.go

@@ -5,7 +5,7 @@ package e2e
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -13,19 +13,18 @@ import (
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"gopkg.in/yaml.v2"
 	"gopkg.in/yaml.v2"
 )
 )
 
 
 func BenchmarkHotPath(b *testing.B) {
 func BenchmarkHotPath(b *testing.B) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
@@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) {
 	r.CancelFlowLogs()
 	r.CancelFlowLogs()
 
 
 	for n := 0; n < b.N; n++ {
 	for n := 0; n < b.N; n++ {
-		myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+		myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
 		_ = r.RouteForAllUntilTxTun(theirControl)
 	}
 	}
 
 
@@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) {
 }
 }
 
 
 func TestGoodHandshake(t *testing.T) {
 func TestGoodHandshake(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.WaitForType(1, 0, theirControl)
 	myControl.WaitForType(1, 0, theirControl)
 
 
 	t.Log("Make sure our host infos are correct")
 	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
 
 
 	t.Log("Get that cached packet and make sure it looks right")
 	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 
 	t.Log("Do a bidirectional tunnel test")
 	t.Log("Do a bidirectional tunnel test")
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
 	defer r.RenderFlow()
 	defer r.RenderFlow()
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
 	myControl.Stop()
@@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) {
 }
 }
 
 
 func TestWrongResponderHandshake(t *testing.T) {
 func TestWrongResponderHandshake(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
 
 
 	// The IPs here are chosen on purpose:
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
-	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil)
+	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil)
 
 
 	// Add their real udp addr, which should be tried after evil.
 	// Add their real udp addr, which should be tried after evil.
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl, evilControl)
 	r := router.NewR(t, myControl, theirControl, evilControl)
@@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 	evilControl.Start()
 	evilControl.Start()
 
 
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 		h := &header.H{}
 		h := &header.H{}
 		err := h.Parse(p.Data)
 		err := h.Parse(p.Data)
@@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 			panic(err)
 			panic(err)
 		}
 		}
 
 
-		if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+		if p.To == theirUdpAddr && h.Type == 1 {
 			return router.RouteAndExit
 			return router.RouteAndExit
 		}
 		}
 
 
@@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 
 	t.Log("My cached packet should be received by them")
 	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 
 	t.Log("Test the tunnel with them")
 	t.Log("Test the tunnel with them")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	t.Log("Flush all packets from all controllers")
 	t.Log("Flush all packets from all controllers")
 	r.FlushAll()
 	r.FlushAll()
 
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil")
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 
 	//TODO: assert hostmaps for everyone
 	//TODO: assert hostmaps for everyone
@@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) {
 	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
 	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
 	// But will eventually collapse down to a single tunnel
 	// But will eventually collapse down to a single tunnel
 
 
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake to start on both me and them")
 	t.Log("Trigger a handshake to start on both me and them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 
 	t.Log("Get both stage 1 handshake packets")
 	t.Log("Get both stage 1 handshake packets")
 	myHsForThem := myControl.GetFromUDP(true)
 	myHsForThem := myControl.GetFromUDP(true)
@@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) {
 
 
 	r.Log("Route until they receive a message packet")
 	r.Log("Route until they receive a message packet")
 	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
 	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 
 	r.Log("Their cached packet should be received by me")
 	r.Log("Their cached packet should be received by me")
 	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
 	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
-	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
 
 
 	r.Log("Do a bidirectional tunnel test")
 	r.Log("Do a bidirectional tunnel test")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	myHostmapHosts := myControl.ListHostmapHosts(false)
 	myHostmapHosts := myControl.ListHostmapHosts(false)
 	myHostmapIndexes := myControl.ListHostmapIndexes(false)
 	myHostmapIndexes := myControl.ListHostmapIndexes(false)
@@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) {
 	r.Log("Spin until connection manager tears down a tunnel")
 	r.Log("Spin until connection manager tears down a tunnel")
 
 
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) {
 }
 }
 
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
 func TestUncleanShutdownRaceLoser(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	r.Log("Trigger a handshake from me to them")
 	r.Log("Trigger a handshake from me to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 
 	r.Log("Nuke my hostmap")
 	r.Log("Nuke my hostmap")
 	myHostmap := myControl.GetHostmap()
 	myHostmap := myControl.GetHostmap()
-	myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
 	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 
 
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again"))
 	p = r.RouteForAllUntilTxTun(theirControl)
 	p = r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	r.Log("Wait for the dead index to go away")
 	r.Log("Wait for the dead index to go away")
 	start := len(theirControl.GetHostmap().Indexes)
 	start := len(theirControl.GetHostmap().Indexes)
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		if len(theirControl.GetHostmap().Indexes) < start {
 		if len(theirControl.GetHostmap().Indexes) < start {
 			break
 			break
 		}
 		}
@@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 }
 }
 
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
 func TestUncleanShutdownRaceWinner(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	r.Log("Trigger a handshake from me to them")
 	r.Log("Trigger a handshake from me to them")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 
 
 	r.Log("Nuke my hostmap")
 	r.Log("Nuke my hostmap")
 	theirHostmap := theirControl.GetHostmap()
 	theirHostmap := theirControl.GetHostmap()
-	theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{}
 	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
 
 
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again"))
 	p = r.RouteForAllUntilTxTun(myControl)
 	p = r.RouteForAllUntilTxTun(myControl)
-	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	r.Log("Wait for the dead index to go away")
 	r.Log("Wait for the dead index to go away")
 	start := len(myControl.GetHostmap().Indexes)
 	start := len(myControl.GetHostmap().Indexes)
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		if len(myControl.GetHostmap().Indexes) < start {
 		if len(myControl.GetHostmap().Indexes) < start {
 			break
 			break
 		}
 		}
@@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 }
 }
 
 
 func TestRelays(t *testing.T) {
 func TestRelays(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -361,31 +360,31 @@ func TestRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 }
 }
 
 
 func TestStage1RaceRelays(t *testing.T) {
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
 
 
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
 
 
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	r.Log("Get a tunnel between me and relay")
 	r.Log("Get a tunnel between me and relay")
-	assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
 
 
 	r.Log("Get a tunnel between them and relay")
 	r.Log("Get a tunnel between them and relay")
-	assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
 
 
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 
 	r.Log("Wait for a packet from them to me")
 	r.Log("Wait for a packet from them to me")
 	p := r.RouteForAllUntilTxTun(myControl)
 	p := r.RouteForAllUntilTxTun(myControl)
@@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) {
 
 
 func TestStage1RaceRelays2(t *testing.T) {
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 	l := NewTestLogger()
 	l := NewTestLogger()
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
 
 
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
 
 
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 	r.Log("Get a tunnel between me and relay")
 	r.Log("Get a tunnel between me and relay")
 	l.Info("Get a tunnel between me and relay")
 	l.Info("Get a tunnel between me and relay")
-	assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
 
 
 	r.Log("Get a tunnel between them and relay")
 	r.Log("Get a tunnel between them and relay")
 	l.Info("Get a tunnel between them and relay")
 	l.Info("Get a tunnel between them and relay")
-	assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
 
 
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	r.Log("Trigger a handshake from both them and me via relay to them and me")
 	l.Info("Trigger a handshake from both them and me via relay to them and me")
 	l.Info("Trigger a handshake from both them and me via relay to them and me")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 
 	//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
 	//r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone)
@@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 
 
 	t.Log("Wait until we remove extra tunnels")
 	t.Log("Wait until we remove extra tunnels")
 	l.Info("Wait until we remove extra tunnels")
 	l.Info("Wait until we remove extra tunnels")
@@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 				"theirControl": len(theirControl.GetHostmap().Indexes),
 				"theirControl": len(theirControl.GetHostmap().Indexes),
 				"relayControl": len(relayControl.GetHostmap().Indexes),
 				"relayControl": len(relayControl.GetHostmap().Indexes),
 			}).Info("Waiting for hostinfos to be removed...")
 			}).Info("Waiting for hostinfos to be removed...")
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 		retries--
 		retries--
@@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) {
 
 
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
 	l.Info("Assert the tunnel works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 
 
 	myControl.Stop()
 	myControl.Stop()
 	theirControl.Stop()
 	theirControl.Stop()
@@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) {
 	//
 	//
 	////TODO: assert hostmaps
 	////TODO: assert hostmaps
 }
 }
+
 func TestRehandshakingRelays(t *testing.T) {
 func TestRehandshakingRelays(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 
 
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
-		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between my and relay is updated!")
 			r.Log("Certificate between my and relay is updated!")
@@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
-		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between their and relay is updated!")
 			r.Log("Certificate between their and relay is updated!")
@@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) {
 	}
 	}
 
 
 	r.Log("Assert the relay tunnel still works")
 	r.Log("Assert the relay tunnel still works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	// We should have two hostinfos on all sides
 	// We should have two hostinfos on all sides
 	for len(myControl.GetHostmap().Indexes) != 2 {
 	for len(myControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) {
 
 
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 func TestRehandshakingRelaysPrimary(t *testing.T) {
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
 	// This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", "10.128.0.128/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", "10.128.0.1/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
 
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
-	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
 	r.Log("Assert the tunnel works")
 	r.Log("Assert the tunnel works")
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 
 
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
 	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
@@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
-		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between my and relay is updated!")
 			r.Log("Certificate between my and relay is updated!")
@@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 
 
 	for {
 	for {
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
 		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
-		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
-		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			r.Log("Certificate between their and relay is updated!")
 			r.Log("Certificate between their and relay is updated!")
@@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	}
 	}
 
 
 	r.Log("Assert the relay tunnel still works")
 	r.Log("Assert the relay tunnel still works")
-	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
 	// We should have two hostinfos on all sides
 	// We should have two hostinfos on all sides
 	for len(myControl.GetHostmap().Indexes) != 2 {
 	for len(myControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 	for len(theirControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 	for len(relayControl.GetHostmap().Indexes) != 2 {
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
 		r.Log("Assert the relay tunnel still works")
 		r.Log("Assert the relay tunnel still works")
-		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r)
 		r.Log("yupitdoes")
 		r.Log("yupitdoes")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
@@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
 }
 }
 
 
 func TestRehandshaking(t *testing.T) {
 func TestRehandshaking(t *testing.T) {
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Stand up a tunnel between me and them")
 	t.Log("Stand up a tunnel between me and them")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 
@@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) {
 	myConfig.ReloadConfigString(string(rc))
 	myConfig.ReloadConfigString(string(rc))
 
 
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
-		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
 		if len(c.Cert.Details.Groups) != 0 {
 		if len(c.Cert.Details.Groups) != 0 {
 			// We have a new certificate now
 			// We have a new certificate now
 			break
 			break
@@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) {
 
 
 	r.Log("Spin until there is only 1 tunnel")
 	r.Log("Spin until there is only 1 tunnel")
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
 
 
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 
 	// Make sure the correct tunnel won
 	// Make sure the correct tunnel won
-	c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
 	assert.Contains(t, c.Cert.Details.Groups, "new group")
 	assert.Contains(t, c.Cert.Details.Groups, "new group")
 
 
 	// We should only have a single tunnel now on both sides
 	// We should only have a single tunnel now on both sides
@@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) {
 func TestRehandshakingLoser(t *testing.T) {
 func TestRehandshakingLoser(t *testing.T) {
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel
 	// Should be the one with the new certificate
 	// Should be the one with the new certificate
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", "10.128.0.2/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Stand up a tunnel between me and them")
 	t.Log("Stand up a tunnel between me and them")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
-	tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
-	tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+	tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
 	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
 	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
 
 
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
@@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) {
 	theirConfig.ReloadConfigString(string(rc))
 	theirConfig.ReloadConfigString(string(rc))
 
 
 	for {
 	for {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
-		theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+		theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
 
 
 		_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
 		_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
 		if theirNewGroup {
 		if theirNewGroup {
@@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) {
 
 
 	r.Log("Spin until there is only 1 tunnel")
 	r.Log("Spin until there is only 1 tunnel")
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
 	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
-		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 		t.Log("Connection manager hasn't ticked yet")
 		t.Log("Connection manager hasn't ticked yet")
 		time.Sleep(time.Second)
 		time.Sleep(time.Second)
 	}
 	}
 
 
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 
 	// Make sure the correct tunnel won
 	// Make sure the correct tunnel won
-	theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+	theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
 	assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
 	assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
 
 
 	// We should only have a single tunnel now on both sides
 	// We should only have a single tunnel now on both sides
@@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) {
 	// This test forces stage 1, stage 2, stage 1 to be received by me from them
 	// This test forces stage 1, stage 2, stage 1 to be received by me from them
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which
 	// caused a cross-linked hostinfo
 	// caused a cross-linked hostinfo
-	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
@@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) {
 	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
 	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
 
 
 	t.Log("Start both handshakes")
 	t.Log("Start both handshakes")
-	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them"))
 
 
 	t.Log("Get both stage 1")
 	t.Log("Get both stage 1")
 	myStage1ForThem := myControl.GetFromUDP(true)
 	myStage1ForThem := myControl.GetFromUDP(true)
@@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) {
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 
 	t.Log("Make sure the tunnel still works")
 	t.Log("Make sure the tunnel still works")
-	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
 
 
 	myControl.Stop()
 	myControl.Stop()
 	theirControl.Stop()
 	theirControl.Stop()

+ 15 - 8
e2e/helpers.go

@@ -4,6 +4,7 @@ import (
 	"crypto/rand"
 	"crypto/rand"
 	"io"
 	"io"
 	"net"
 	"net"
+	"net/netip"
 	"time"
 	"time"
 
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
@@ -12,7 +13,7 @@ import (
 )
 )
 
 
 // NewTestCaCert will generate a CA cert
 // NewTestCaCert will generate a CA cert
-func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
 	if before.IsZero() {
 	if before.IsZero() {
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
 		before = time.Now().Add(time.Second * -60).Round(time.Second)
@@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 	}
 	}
 
 
 	if len(ips) > 0 {
 	if len(ips) > 0 {
-		nc.Details.Ips = ips
+		nc.Details.Ips = make([]*net.IPNet, len(ips))
+		for i, ip := range ips {
+			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+		}
 	}
 	}
 
 
 	if len(subnets) > 0 {
 	if len(subnets) > 0 {
-		nc.Details.Subnets = subnets
+		nc.Details.Subnets = make([]*net.IPNet, len(subnets))
+		for i, ip := range subnets {
+			nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
+		}
 	}
 	}
 
 
 	if len(groups) > 0 {
 	if len(groups) > 0 {
@@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []
 
 
 // NewTestCert will generate a signed certificate with the provided details.
 // NewTestCert will generate a signed certificate with the provided details.
 // Expiry times are defaulted if you do not pass them in
 // Expiry times are defaulted if you do not pass them in
-func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
+func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
 	issuer, err := ca.Sha256Sum()
 	issuer, err := ca.Sha256Sum()
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
@@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
 	}
 	}
 
 
 	pub, rawPriv := x25519Keypair()
 	pub, rawPriv := x25519Keypair()
-
+	ipb := ip.Addr().AsSlice()
 	nc := &cert.NebulaCertificate{
 	nc := &cert.NebulaCertificate{
 		Details: cert.NebulaCertificateDetails{
 		Details: cert.NebulaCertificateDetails{
-			Name:           name,
-			Ips:            []*net.IPNet{ip},
-			Subnets:        subnets,
+			Name: name,
+			Ips:  []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
+			//Subnets:        subnets,
 			Groups:         groups,
 			Groups:         groups,
 			NotBefore:      time.Unix(before.Unix(), 0),
 			NotBefore:      time.Unix(before.Unix(), 0),
 			NotAfter:       time.Unix(after.Unix(), 0),
 			NotAfter:       time.Unix(after.Unix(), 0),

+ 28 - 24
e2e/helpers_test.go

@@ -6,7 +6,7 @@ package e2e
 import (
 import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"testing"
 	"testing"
 	"time"
 	"time"
@@ -19,7 +19,6 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/e2e/router"
 	"github.com/slackhq/nebula/e2e/router"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"gopkg.in/yaml.v2"
 	"gopkg.in/yaml.v2"
 )
 )
@@ -27,15 +26,23 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 // newSimpleServer creates a nebula instance with many assumptions
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
 	l := NewTestLogger()
 	l := NewTestLogger()
 
 
-	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
-	copy(vpnIpNet.IP, udpIp)
-	vpnIpNet.IP[1] += 128
-	udpAddr := net.UDPAddr{
-		IP:   udpIp,
-		Port: 4242,
+	vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
+	if err != nil {
+		panic(err)
+	}
+
+	var udpAddr netip.AddrPort
+	if vpnIpNet.Addr().Is4() {
+		budpIp := vpnIpNet.Addr().As4()
+		budpIp[1] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242)
+	} else {
+		budpIp := vpnIpNet.Addr().As16()
+		budpIp[13] -= 128
+		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
 	}
 	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 	_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
 
 
@@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		//	"try_interval": "1s",
 		//	"try_interval": "1s",
 		//},
 		//},
 		"listen": m{
 		"listen": m{
-			"host": udpAddr.IP.String(),
-			"port": udpAddr.Port,
+			"host": udpAddr.Addr().String(),
+			"port": udpAddr.Port(),
 		},
 		},
 		"logging": m{
 		"logging": m{
 			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
 			"timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name),
@@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 		panic(err)
 	}
 	}
 
 
-	return control, vpnIpNet, &udpAddr, c
+	return control, vpnIpNet, udpAddr, c
 }
 }
 
 
 type doneCb func()
 type doneCb func()
@@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 	}
 	}
 }
 }
 
 
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	// Send a packet from them to me
 	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
 	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
 	bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 }
 
 
-func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) {
+func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) {
 	// Get both host infos
 	// Get both host infos
-	hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false)
+	hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false)
 	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
 	assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA")
 
 
-	hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false)
+	hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false)
 	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
 	assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB")
 
 
 	// Check that both vpn and real addr are correct
 	// Check that both vpn and real addr are correct
 	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
 	assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A")
 	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
 	assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B")
 
 
-	assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A")
-	assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B")
-
-	assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A")
-	assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B")
+	assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A")
+	assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B")
 
 
 	// Check that our indexes match
 	// Check that our indexes match
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
 	assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index")
@@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB
 	//checkIndexes("hmB", hmB, hAinB)
 	//checkIndexes("hmB", hmB, hAinB)
 }
 }
 
 
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) {
+func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")
 	assert.NotNil(t, v4, "No ipv4 data found")
 
 
-	assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect")
-	assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect")
+	assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect")
+	assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect")
 
 
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
 	assert.NotNil(t, udp, "No udp data found")
 	assert.NotNil(t, udp, "No udp data found")

+ 4 - 4
e2e/router/hostmap.go

@@ -5,11 +5,11 @@ package router
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"net/netip"
 	"sort"
 	"sort"
 	"strings"
 	"strings"
 
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type edge struct {
 type edge struct {
@@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	return r, globalLines
 	return r, globalLines
 }
 }
 
 
-func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
-	keys := make([]iputil.VpnIp, 0, len(hosts))
+func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr {
+	keys := make([]netip.Addr, 0, len(hosts))
 	for key := range hosts {
 	for key := range hosts {
 		keys = append(keys, key)
 		keys = append(keys, key)
 	}
 	}
 
 
 	sort.SliceStable(keys, func(i, j int) bool {
 	sort.SliceStable(keys, func(i, j int) bool {
-		return keys[i] > keys[j]
+		return keys[i].Compare(keys[j]) > 0
 	})
 	})
 
 
 	return keys
 	return keys

+ 36 - 63
e2e/router/router.go

@@ -6,12 +6,11 @@ package router
 import (
 import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"reflect"
 	"reflect"
 	"sort"
 	"sort"
-	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"testing"
 	"testing"
@@ -21,7 +20,6 @@ import (
 	"github.com/google/gopacket/layers"
 	"github.com/google/gopacket/layers"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"golang.org/x/exp/maps"
 	"golang.org/x/exp/maps"
 )
 )
@@ -29,18 +27,18 @@ import (
 type R struct {
 type R struct {
 	// Simple map of the ip:port registered on a control to the control
 	// Simple map of the ip:port registered on a control to the control
 	// Basically a router, right?
 	// Basically a router, right?
-	controls map[string]*nebula.Control
+	controls map[netip.AddrPort]*nebula.Control
 
 
 	// A map for inbound packets for a control that doesn't know about this address
 	// A map for inbound packets for a control that doesn't know about this address
-	inNat map[string]*nebula.Control
+	inNat map[netip.AddrPort]*nebula.Control
 
 
 	// A last used map, if an inbound packet hit the inNat map then
 	// A last used map, if an inbound packet hit the inNat map then
 	// all return packets should use the same last used inbound address for the outbound sender
 	// all return packets should use the same last used inbound address for the outbound sender
 	// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
 	// map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver
-	outNat map[string]net.UDPAddr
+	outNat map[string]netip.AddrPort
 
 
 	// A map of vpn ip to the nebula control it belongs to
 	// A map of vpn ip to the nebula control it belongs to
-	vpnControls map[iputil.VpnIp]*nebula.Control
+	vpnControls map[netip.Addr]*nebula.Control
 
 
 	ignoreFlows []ignoreFlow
 	ignoreFlows []ignoreFlow
 	flow        []flowEntry
 	flow        []flowEntry
@@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	}
 	}
 
 
 	r := &R{
 	r := &R{
-		controls:     make(map[string]*nebula.Control),
-		vpnControls:  make(map[iputil.VpnIp]*nebula.Control),
-		inNat:        make(map[string]*nebula.Control),
-		outNat:       make(map[string]net.UDPAddr),
+		controls:     make(map[netip.AddrPort]*nebula.Control),
+		vpnControls:  make(map[netip.Addr]*nebula.Control),
+		inNat:        make(map[netip.AddrPort]*nebula.Control),
+		outNat:       make(map[string]netip.AddrPort),
 		flow:         []flowEntry{},
 		flow:         []flowEntry{},
 		ignoreFlows:  []ignoreFlow{},
 		ignoreFlows:  []ignoreFlow{},
 		fn:           filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
 		fn:           filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())),
@@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 	for _, c := range controls {
 	for _, c := range controls {
 		addr := c.GetUDPAddr()
 		addr := c.GetUDPAddr()
 		if _, ok := r.controls[addr]; ok {
 		if _, ok := r.controls[addr]; ok {
-			panic("Duplicate listen address: " + addr)
+			panic("Duplicate listen address: " + addr.String())
 		}
 		}
 
 
 		r.vpnControls[c.GetVpnIp()] = c
 		r.vpnControls[c.GetVpnIp()] = c
@@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 // It does not look at the addr attached to the instance.
 // It does not look at the addr attached to the instance.
 // If a route is used, this will behave like a NAT for the return path.
 // If a route is used, this will behave like a NAT for the return path.
 // Rewriting the source ip:port to what was last sent to from the origin
 // Rewriting the source ip:port to what was last sent to from the origin
-func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
+func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) {
 	r.Lock()
 	r.Lock()
 	defer r.Unlock()
 	defer r.Unlock()
 
 
-	inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))
+	inAddr := netip.AddrPortFrom(ip, port)
 	if _, ok := r.inNat[inAddr]; ok {
 	if _, ok := r.inNat[inAddr]; ok {
-		panic("Duplicate listen address inNat: " + inAddr)
+		panic("Duplicate listen address inNat: " + inAddr.String())
 	}
 	}
 	r.inNat[inAddr] = c
 	r.inNat[inAddr] = c
 }
 }
@@ -198,7 +196,7 @@ func (r *R) renderFlow() {
 		panic(err)
 		panic(err)
 	}
 	}
 
 
-	var participants = map[string]struct{}{}
+	var participants = map[netip.AddrPort]struct{}{}
 	var participantsVals []string
 	var participantsVals []string
 
 
 	fmt.Fprintln(f, "```mermaid")
 	fmt.Fprintln(f, "```mermaid")
@@ -215,7 +213,7 @@ func (r *R) renderFlow() {
 			continue
 			continue
 		}
 		}
 		participants[addr] = struct{}{}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr, ":", "-", 1)
+		sanAddr := strings.Replace(addr.String(), ":", "-", 1)
 		participantsVals = append(participantsVals, sanAddr)
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
@@ -252,9 +250,9 @@ func (r *R) renderFlow() {
 
 
 			fmt.Fprintf(f,
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
+				strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1),
 				line,
 				line,
-				strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+				strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 			)
 		}
 		}
@@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 func (r *R) renderHostmaps(title string) {
 func (r *R) renderHostmaps(title string) {
 	c := maps.Values(r.controls)
 	c := maps.Values(r.controls)
 	sort.SliceStable(c, func(i, j int) bool {
 	sort.SliceStable(c, func(i, j int) bool {
-		return c[i].GetVpnIp() > c[j].GetVpnIp()
+		return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0
 	})
 	})
 
 
 	s := renderHostmaps(c...)
 	s := renderHostmaps(c...)
@@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 
 
 		// Nope, lets push the sender along
 		// Nope, lets push the sender along
 		case p := <-udpTx:
 		case p := <-udpTx:
-			outAddr := sender.GetUDPAddr()
 			r.Lock()
 			r.Lock()
-			inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-			c := r.getControl(outAddr, inAddr, p)
+			c := r.getControl(sender.GetUDPAddr(), p.To, p)
 			if c == nil {
 			if c == nil {
 				r.Unlock()
 				r.Unlock()
 				panic("No control for udp tx")
 				panic("No control for udp tx")
@@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte {
 		} else {
 		} else {
 			// we are a udp tx, route and continue
 			// we are a udp tx, route and continue
 			p := rx.Interface().(*udp.Packet)
 			p := rx.Interface().(*udp.Packet)
-			outAddr := cm[x].GetUDPAddr()
-
-			inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-			c := r.getControl(outAddr, inAddr, p)
+			c := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 			if c == nil {
 			if c == nil {
 				r.Unlock()
 				r.Unlock()
 				panic("No control for udp tx")
 				panic("No control for udp tx")
@@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 			panic(err)
 			panic(err)
 		}
 		}
 
 
-		outAddr := sender.GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(sender.GetUDPAddr(), p.To, p)
 		if receiver == nil {
 		if receiver == nil {
 			r.Unlock()
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't RouteExitFunc for host: " + p.To.String())
 		}
 		}
 
 
 		e := whatDo(p, receiver)
 		e := whatDo(p, receiver)
@@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet
 // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
 // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
 // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
 // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
 // If the router doesn't have the nebula controller for that address, we panic
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) {
 	if finish == KeepRouting {
 	if finish == KeepRouting {
 		finish = RouteAndExit
 		finish = RouteAndExit
 	}
 	}
 
 
 	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
 	r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType {
-		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
+		if p.To == toAddr {
 			return finish
 			return finish
 		}
 		}
 
 
@@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
 		r.Lock()
 		r.Lock()
 
 
 		p := rx.Interface().(*udp.Packet)
 		p := rx.Interface().(*udp.Packet)
-
-		outAddr := cm[x].GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 		if receiver == nil {
 		if receiver == nil {
 			r.Unlock()
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't RouteForAllExitFunc for host: " + p.To.String())
 		}
 		}
 
 
 		e := whatDo(p, receiver)
 		e := whatDo(p, receiver)
@@ -697,12 +685,10 @@ func (r *R) FlushAll() {
 
 
 		p := rx.Interface().(*udp.Packet)
 		p := rx.Interface().(*udp.Packet)
 
 
-		outAddr := cm[x].GetUDPAddr()
-		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
-		receiver := r.getControl(outAddr, inAddr, p)
+		receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p)
 		if receiver == nil {
 		if receiver == nil {
 			r.Unlock()
 			r.Unlock()
-			panic("Can't route for host: " + inAddr)
+			panic("Can't FlushAll for host: " + p.To.String())
 		}
 		}
 		r.Unlock()
 		r.Unlock()
 	}
 	}
@@ -710,28 +696,14 @@ func (r *R) FlushAll() {
 
 
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
 // This is an internal router function, the caller must hold the lock
-func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control {
-	if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok {
-		p.FromIp = newAddr.IP
-		p.FromPort = uint16(newAddr.Port)
+func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control {
+	if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok {
+		p.From = newAddr
 	}
 	}
 
 
 	c, ok := r.inNat[toAddr]
 	c, ok := r.inNat[toAddr]
 	if ok {
 	if ok {
-		sHost, sPort, err := net.SplitHostPort(toAddr)
-		if err != nil {
-			panic(err)
-		}
-
-		port, err := strconv.Atoi(sPort)
-		if err != nil {
-			panic(err)
-		}
-
-		r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{
-			IP:   net.ParseIP(sHost),
-			Port: port,
-		}
+		r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr
 		return c
 		return c
 	}
 	}
 
 
@@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string {
 	}
 	}
 
 
 	from := "unknown"
 	from := "unknown"
-	if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok {
-		from = c.GetUDPAddr()
+	srcAddr, _ := netip.AddrFromSlice(v4.SrcIP)
+	if c, ok := r.vpnControls[srcAddr]; ok {
+		from = c.GetUDPAddr().String()
 	}
 	}
 
 
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
 	udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP)
@@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string {
 	return fmt.Sprintf(
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
 		strings.Replace(from, ":", "-", 1),
 		strings.Replace(from, ":", "-", 1),
-		strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
+		strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1),
 		udp.SrcPort,
 		udp.SrcPort,
 		udp.DstPort,
 		udp.DstPort,
 		string(data.Payload()),
 		string(data.Payload()),

+ 58 - 42
firewall.go

@@ -6,23 +6,23 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"hash/fnv"
 	"hash/fnv"
-	"net"
+	"net/netip"
 	"reflect"
 	"reflect"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 )
 )
 
 
 type FirewallInterface interface {
 type FirewallInterface interface {
-	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
+	AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
 }
 }
 
 
 type conn struct {
 type conn struct {
@@ -52,8 +52,8 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 	DefaultTimeout time.Duration //linux: 600s
 
 
 	// Used to ensure we don't emit local packets for ips we don't own
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps     *cidr.Tree4[struct{}]
-	assignedCIDR *net.IPNet
+	localIps     *bart.Table[struct{}]
+	assignedCIDR netip.Prefix
 	hasSubnets   bool
 	hasSubnets   bool
 
 
 	rules        string
 	rules        string
@@ -108,7 +108,7 @@ type FirewallRule struct {
 	Any    *firewallLocalCIDR
 	Any    *firewallLocalCIDR
 	Hosts  map[string]*firewallLocalCIDR
 	Hosts  map[string]*firewallLocalCIDR
 	Groups []*firewallGroups
 	Groups []*firewallGroups
-	CIDR   *cidr.Tree4[*firewallLocalCIDR]
+	CIDR   *bart.Table[*firewallLocalCIDR]
 }
 }
 
 
 type firewallGroups struct {
 type firewallGroups struct {
@@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
 
 
 type firewallLocalCIDR struct {
 type firewallLocalCIDR struct {
 	Any       bool
 	Any       bool
-	LocalCIDR *cidr.Tree4[struct{}]
+	LocalCIDR *bart.Table[struct{}]
 }
 }
 
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 		max = defaultTimeout
 	}
 	}
 
 
-	localIps := cidr.NewTree4[struct{}]()
-	var assignedCIDR *net.IPNet
+	localIps := new(bart.Table[struct{}])
+	var assignedCIDR netip.Prefix
+	var assignedSet bool
 	for _, ip := range c.Details.Ips {
 	for _, ip := range c.Details.Ips {
-		ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
-		localIps.AddCIDR(ipNet, struct{}{})
+		//TODO: IPV6-WORK the unmap is a bit unfortunate
+		nip, _ := netip.AddrFromSlice(ip.IP)
+		nip = nip.Unmap()
+		nprefix := netip.PrefixFrom(nip, nip.BitLen())
+		localIps.Insert(nprefix, struct{}{})
 
 
-		if assignedCIDR == nil {
+		if !assignedSet {
 			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
 			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
-			assignedCIDR = ipNet
+			assignedCIDR = nprefix
+			assignedSet = true
 		}
 		}
 	}
 	}
 
 
 	for _, n := range c.Details.Subnets {
 	for _, n := range c.Details.Subnets {
-		localIps.AddCIDR(n, struct{}{})
+		nip, _ := netip.AddrFromSlice(n.IP)
+		ones, _ := n.Mask.Size()
+		nip = nip.Unmap()
+		localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
 	}
 	}
 
 
 	return &Firewall{
 	return &Firewall{
@@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
 }
 }
 
 
 // AddRule properly creates the in memory rule structure for a firewall table.
 // AddRule properly creates the in memory rule structure for a firewall table.
-func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
 	// https://github.com/golang/go/issues/14131
 	// https://github.com/golang/go/issues/14131
 	sIp := ""
 	sIp := ""
-	if ip != nil {
+	if ip.IsValid() {
 		sIp = ip.String()
 		sIp = ip.String()
 	}
 	}
 	lIp := ""
 	lIp := ""
-	if localIp != nil {
+	if localIp.IsValid() {
 		lIp = localIp.String()
 		lIp = localIp.String()
 	}
 	}
 
 
@@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 			return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
 		}
 		}
 
 
-		var cidr *net.IPNet
+		var cidr netip.Prefix
 		if r.Cidr != "" {
 		if r.Cidr != "" {
-			_, cidr, err = net.ParseCIDR(r.Cidr)
+			cidr, err = netip.ParsePrefix(r.Cidr)
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
 				return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
 			}
 			}
 		}
 		}
 
 
-		var localCidr *net.IPNet
+		var localCidr netip.Prefix
 		if r.LocalCidr != "" {
 		if r.LocalCidr != "" {
-			_, localCidr, err = net.ParseCIDR(r.LocalCidr)
+			localCidr, err = netip.ParsePrefix(r.LocalCidr)
 			if err != nil {
 			if err != nil {
 				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
 				return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
 			}
 			}
@@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 
 
 	// Make sure remote address matches nebula certificate
 	// Make sure remote address matches nebula certificate
 	if remoteCidr := h.remoteCidr; remoteCidr != nil {
 	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		ok, _ := remoteCidr.Contains(fp.RemoteIP)
+		//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+		_, ok := remoteCidr.Lookup(fp.RemoteIP)
 		if !ok {
 		if !ok {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
 			return ErrInvalidRemoteIP
@@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 	}
 	}
 
 
 	// Make sure we are supposed to be handling this local ip address
 	// Make sure we are supposed to be handling this local ip address
-	ok, _ := f.localIps.Contains(fp.LocalIP)
+	//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
+	_, ok := f.localIps.Lookup(fp.LocalIP)
 	if !ok {
 	if !ok {
 		f.metrics(incoming).droppedLocalIP.Inc(1)
 		f.metrics(incoming).droppedLocalIP.Inc(1)
 		return ErrInvalidLocalIP
 		return ErrInvalidLocalIP
@@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Caller must own the connMutex lock!
 // Caller must own the connMutex lock!
 func (f *Firewall) evict(p firewall.Packet) {
 func (f *Firewall) evict(p firewall.Packet) {
-	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
 	// Are we still tracking this conn?
 	conntrack := f.Conntrack
 	conntrack := f.Conntrack
 	t, ok := conntrack.Conns[p]
 	t, ok := conntrack.Conns[p]
@@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 	return false
 	return false
 }
 }
 
 
-func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
 	if startPort > endPort {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 		return fmt.Errorf("start port was lower than end port")
 	}
 	}
@@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 }
 
 
-func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
 	fr := func() *FirewallRule {
 	fr := func() *FirewallRule {
 		return &FirewallRule{
 		return &FirewallRule{
 			Hosts:  make(map[string]*firewallLocalCIDR),
 			Hosts:  make(map[string]*firewallLocalCIDR),
 			Groups: make([]*firewallGroups, 0),
 			Groups: make([]*firewallGroups, 0),
-			CIDR:   cidr.NewTree4[*firewallLocalCIDR](),
+			CIDR:   new(bart.Table[*firewallLocalCIDR]),
 		}
 		}
 	}
 	}
 
 
@@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 	return fc.CANames[s.Details.Name].match(p, c)
 	return fc.CANames[s.Details.Name].match(p, c)
 }
 }
 
 
-func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
 	flc := func() *firewallLocalCIDR {
 	flc := func() *firewallLocalCIDR {
 		return &firewallLocalCIDR{
 		return &firewallLocalCIDR{
-			LocalCIDR: cidr.NewTree4[struct{}](),
+			LocalCIDR: new(bart.Table[struct{}]),
 		}
 		}
 	}
 	}
 
 
@@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
 		fr.Hosts[host] = nlc
 		fr.Hosts[host] = nlc
 	}
 	}
 
 
-	if ip != nil {
-		_, nlc := fr.CIDR.GetCIDR(ip)
+	if ip.IsValid() {
+		nlc, _ := fr.CIDR.Get(ip)
 		if nlc == nil {
 		if nlc == nil {
 			nlc = flc()
 			nlc = flc()
 		}
 		}
@@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		fr.CIDR.AddCIDR(ip, nlc)
+		fr.CIDR.Insert(ip, nlc)
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
-	if len(groups) == 0 && host == "" && ip == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
+	if len(groups) == 0 && host == "" && !ip.IsValid() {
 		return true
 		return true
 	}
 	}
 
 
@@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 		return true
 		return true
 	}
 	}
 
 
-	if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
+	if ip.IsValid() && ip.Bits() == 0 {
 		return true
 		return true
 	}
 	}
 
 
@@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		}
 		}
 	}
 	}
 
 
-	return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
-		return flc.match(p, c)
+	matched := false
+	prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
+	fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
+		if prefix.Contains(p.RemoteIP) && val.match(p, c) {
+			matched = true
+			return false
+		}
+		return true
 	})
 	})
+	return matched
 }
 }
 
 
-func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
-	if localIp == nil {
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
+	if !localIp.IsValid() {
 		if !f.hasSubnets || f.defaultLocalCIDRAny {
 		if !f.hasSubnets || f.defaultLocalCIDRAny {
 			flc.Any = true
 			flc.Any = true
 			return nil
 			return nil
 		}
 		}
 
 
 		localIp = f.assignedCIDR
 		localIp = f.assignedCIDR
-	} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
+	} else if localIp.Bits() == 0 {
 		flc.Any = true
 		flc.Any = true
 	}
 	}
 
 
-	flc.LocalCIDR.AddCIDR(localIp, struct{}{})
+	flc.LocalCIDR.Insert(localIp, struct{}{})
 	return nil
 	return nil
 }
 }
 
 
@@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
 		return true
 		return true
 	}
 	}
 
 
-	ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
+	_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
 	return ok
 	return ok
 }
 }
 
 

+ 3 - 4
firewall/packet.go

@@ -3,8 +3,7 @@ package firewall
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 )
 
 
 type m map[string]interface{}
 type m map[string]interface{}
@@ -20,8 +19,8 @@ const (
 )
 )
 
 
 type Packet struct {
 type Packet struct {
-	LocalIP    iputil.VpnIp
-	RemoteIP   iputil.VpnIp
+	LocalIP    netip.Addr
+	RemoteIP   netip.Addr
 	LocalPort  uint16
 	LocalPort  uint16
 	RemotePort uint16
 	RemotePort uint16
 	Protocol   uint8
 	Protocol   uint8

+ 74 - 73
firewall_test.go

@@ -5,13 +5,13 @@ import (
 	"errors"
 	"errors"
 	"math"
 	"math"
 	"net"
 	"net"
+	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
@@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
 	assert.NotNil(t, fw.OutRules)
 
 
-	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
+	ti, err := netip.ParsePrefix("1.2.3.4/32")
+	assert.NoError(t, err)
 
 
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	// An empty rule is any
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.InRules.UDP[1].Any.Any)
 	assert.Nil(t, fw.InRules.UDP[1].Any.Any)
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
 	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
 	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
-	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
+	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
 	assert.True(t, ok)
 	assert.True(t, ok)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
-	ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
+	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
 	assert.True(t, ok)
 	assert.True(t, ok)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
+	anyIp, err := netip.ParsePrefix("0.0.0.0/0")
+	assert.NoError(t, err)
+
+	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 
 	// Test error conditions
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
-	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
+	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 }
 }
 
 
 func TestFirewall_Drop(t *testing.T) {
 func TestFirewall_Drop(t *testing.T) {
@@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr("1.2.3.4"),
 	}
 	}
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
-	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
+	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
 	// Allow inbound
 	// Allow inbound
 	resetConntrack(fw)
 	resetConntrack(fw)
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
@@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) {
 
 
 	// test remote mismatch
 	// test remote mismatch
 	oldRemote := p.RemoteIP
 	oldRemote := p.RemoteIP
-	p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
+	p.RemoteIP = netip.MustParseAddr("1.2.3.10")
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	p.RemoteIP = oldRemote
 	p.RemoteIP = oldRemote
 
 
 	// ensure signer doesn't get in the way of group checks
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caSha doesn't drop on match
 	// test caSha doesn't drop on match
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 
 
 	// ensure ca name doesn't get in the way of group checks
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
 
 
 	// test caName doesn't drop on match
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 	assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 }
 
 
@@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		TCP: firewallPort{},
 		TCP: firewallPort{},
 	}
 	}
 
 
-	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
-	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
-	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
+	pfix := netip.MustParsePrefix("172.1.1.1/32")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	b.Run("fail on proto", func(b *testing.B) {
 	b.Run("fail on proto", func(b *testing.B) {
@@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 
 
 	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
 	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
 		c := &cert.NebulaCertificate{}
 		c := &cert.NebulaCertificate{}
-		ip, _, _ := net.ParseCIDR("9.254.254.254/32")
-		lip := iputil.Ip2VpnIp(ip)
+		ip := netip.MustParsePrefix("9.254.254.254/32")
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
@@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
@@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 			},
 		}
 		}
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp))
 		}
 		}
 	})
 	})
 
 
@@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	}
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
@@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) {
 	h1.CreateRemoteCIDR(&c1)
 	h1.CreateRemoteCIDR(&c1)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// h1/c1 lacks the proper groups
 	// h1/c1 lacks the proper groups
@@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  1,
 		LocalPort:  1,
 		RemotePort: 1,
 		RemotePort: 1,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 			peerCert: &c1,
 		},
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	}
 	h1.CreateRemoteCIDR(&c1)
 	h1.CreateRemoteCIDR(&c1)
 
 
@@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c2,
 			peerCert: &c2,
 		},
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	}
 	h2.CreateRemoteCIDR(&c2)
 	h2.CreateRemoteCIDR(&c2)
 
 
@@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c3,
 			peerCert: &c3,
 		},
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	}
 	h3.CreateRemoteCIDR(&c3)
 	h3.CreateRemoteCIDR(&c3)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// c1 should pass because host match
 	// c1 should pass because host match
@@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	l.SetOutput(ob)
 	l.SetOutput(ob)
 
 
 	p := firewall.Packet{
 	p := firewall.Packet{
-		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
-		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
+		LocalIP:    netip.MustParseAddr("1.2.3.4"),
+		RemoteIP:   netip.MustParseAddr("1.2.3.4"),
 		LocalPort:  10,
 		LocalPort:  10,
 		RemotePort: 90,
 		RemotePort: 90,
 		Protocol:   firewall.ProtoUDP,
 		Protocol:   firewall.ProtoUDP,
@@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
 			peerCert: &c,
 			peerCert: &c,
 		},
 		},
-		vpnIp: iputil.Ip2VpnIp(ipNet.IP),
+		vpnIp: netip.MustParseAddr(ipNet.IP.String()),
 	}
 	}
 	h.CreateRemoteCIDR(&c)
 	h.CreateRemoteCIDR(&c)
 
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	cp := cert.NewCAPool()
 	cp := cert.NewCAPool()
 
 
 	// Drop outbound
 	// Drop outbound
@@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 
 	oldFw := fw
 	oldFw := fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 
@@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 
 
 	oldFw = fw
 	oldFw = fw
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	fw.Conntrack = oldFw.Conntrack
 	fw.Conntrack = oldFw.Conntrack
 	fw.rulesVersion = oldFw.rulesVersion + 1
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 
@@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) {
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
+	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 
 	// Test local_cidr parse error
 	// Test local_cidr parse error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
 	_, err = NewFirewallFromConfig(l, c, conf)
 	_, err = NewFirewallFromConfig(l, c, conf)
-	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
+	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
 
 
 	// Test both group and groups
 	// Test both group and groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
@@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	mf := &mockFirewall{}
 	mf := &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test adding udp rule
 	// Test adding udp rule
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test adding icmp rule
 	// Test adding icmp rule
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test adding any rule
 	// Test adding any rule
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test adding rule with cidr
 	// Test adding rule with cidr
-	cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
+	cidr := netip.MustParsePrefix("10.0.0.0/8")
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test adding rule with local_cidr
 	// Test adding rule with local_cidr
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
 
 
 	// Test adding rule with ca_sha
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
 
 
 	// Test adding rule with ca_name
 	// Test adding rule with ca_name
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
 
 
 	// Test single group
 	// Test single group
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test single groups
 	// Test single groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test multiple AND groups
 	// Test multiple AND groups
 	conf = config.NewC(l)
 	conf = config.NewC(l)
 	mf = &mockFirewall{}
 	mf = &mockFirewall{}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
-	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
 
 
 	// Test Add error
 	// Test Add error
 	conf = config.NewC(l)
 	conf = config.NewC(l)
@@ -871,8 +872,8 @@ type addRuleCall struct {
 	endPort   int32
 	endPort   int32
 	groups    []string
 	groups    []string
 	host      string
 	host      string
-	ip        *net.IPNet
-	localIp   *net.IPNet
+	ip        netip.Prefix
+	localIp   netip.Prefix
 	caName    string
 	caName    string
 	caSha     string
 	caSha     string
 }
 }
@@ -882,7 +883,7 @@ type mockFirewall struct {
 	nextCallReturn error
 	nextCallReturn error
 }
 }
 
 
-func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
 	mf.lastCall = addRuleCall{
 	mf.lastCall = addRuleCall{
 		incoming:  incoming,
 		incoming:  incoming,
 		proto:     proto,
 		proto:     proto,

+ 2 - 0
go.mod

@@ -38,8 +38,10 @@ require (
 
 
 require (
 require (
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
+	github.com/bits-and-blooms/bitset v1.13.0 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/gaissmai/bart v0.11.1 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_model v0.5.0 // indirect
 	github.com/prometheus/client_model v0.5.0 // indirect

+ 6 - 0
go.sum

@@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
+github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE=
+github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
 github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
 github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -24,6 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
 github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
+github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI=
+github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
+github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc=
+github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=

+ 42 - 16
handshake_ix.go

@@ -1,13 +1,12 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"net/netip"
 	"time"
 	"time"
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // NOISE IX Handshakes
 // NOISE IX Handshakes
@@ -63,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	return true
 	return true
 }
 }
 
 
-func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	certState := f.pki.GetCertState()
 	certState := f.pki.GetCertState()
 	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
 	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	// Mark packet 1 as seen so it doesn't show up as missed
@@ -99,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		e.Info("Invalid certificate from host")
 		e.Info("Invalid certificate from host")
 		return
 		return
 	}
 	}
-	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+
+	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+	if !ok {
+		e := f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
+
+		if f.l.Level > logrus.DebugLevel {
+			e = e.WithField("cert", remoteCert)
+		}
+
+		e.Info("Invalid vpn ip from host")
+		return
+	}
+
+	vpnIp = vpnIp.Unmap()
 	certName := remoteCert.Details.Name
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 	issuer := remoteCert.Details.Issuer
 
 
-	if vpnIp == f.myVpnIp {
+	if vpnIp == f.myVpnNet.Addr() {
 		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			WithField("certName", certName).
 			WithField("certName", certName).
 			WithField("fingerprint", fingerprint).
 			WithField("fingerprint", fingerprint).
@@ -113,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		return
 		return
 	}
 	}
 
 
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) {
+	if addr.IsValid() {
+		if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) {
 			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 			return
 		}
 		}
@@ -138,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		HandshakePacket:   make(map[uint8][]byte, 0),
 		lastHandshakeTime: hs.Details.Time,
 		lastHandshakeTime: hs.Details.Time,
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 		},
 	}
 	}
@@ -218,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 
 
 			msg = existing.HandshakePacket[2]
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-			if addr != nil {
+			if addr.IsValid() {
 				err := f.outside.WriteTo(msg, addr)
 				err := f.outside.WriteTo(msg, addr)
 				if err != nil {
 				if err != nil {
 					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
 					f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr).
@@ -284,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 
 
 	// Do the send
 	// Do the send
 	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
 	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-	if addr != nil {
+	if addr.IsValid() {
 		err = f.outside.WriteTo(msg, addr)
 		err = f.outside.WriteTo(msg, addr)
 		if err != nil {
 		if err != nil {
 			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
@@ -326,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	return
 	return
 }
 }
 
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
 	if hh == nil {
 	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
 		return true
@@ -336,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	defer hh.Unlock()
 	defer hh.Unlock()
 
 
 	hostinfo := hh.hostinfo
 	hostinfo := hh.hostinfo
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
+	if addr.IsValid() {
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 			return false
 		}
 		}
@@ -389,7 +402,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 		return true
 		return true
 	}
 	}
 
 
-	vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP)
+	vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP)
+	if !ok {
+		e := f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
+
+		if f.l.Level > logrus.DebugLevel {
+			e = e.WithField("cert", remoteCert)
+		}
+
+		e.Info("Invalid vpn ip from host")
+		return true
+	}
+
+	vpnIp = vpnIp.Unmap()
 	certName := remoteCert.Details.Name
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 	fingerprint, _ := remoteCert.Sha256Sum()
 	issuer := remoteCert.Details.Issuer
 	issuer := remoteCert.Details.Issuer
@@ -453,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
 	ci.eKey = NewNebulaCipherState(eKey)
 	ci.eKey = NewNebulaCipherState(eKey)
 
 
 	// Make sure the current udpAddr being used is set for responding
 	// Make sure the current udpAddr being used is set for responding
-	if addr != nil {
+	if addr.IsValid() {
 		hostinfo.SetRemote(addr)
 		hostinfo.SetRemote(addr)
 	} else {
 	} else {
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp)

+ 50 - 41
handshake_manager.go

@@ -6,15 +6,15 @@ import (
 	"crypto/rand"
 	"crypto/rand"
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
-	"net"
+	"net/netip"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
+	"golang.org/x/exp/slices"
 )
 )
 
 
 const (
 const (
@@ -46,14 +46,14 @@ type HandshakeManager struct {
 	// Mutex for interacting with the vpnIps and indexes maps
 	// Mutex for interacting with the vpnIps and indexes maps
 	sync.RWMutex
 	sync.RWMutex
 
 
-	vpnIps  map[iputil.VpnIp]*HandshakeHostInfo
+	vpnIps  map[netip.Addr]*HandshakeHostInfo
 	indexes map[uint32]*HandshakeHostInfo
 	indexes map[uint32]*HandshakeHostInfo
 
 
 	mainHostMap            *HostMap
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
 	lightHouse             *LightHouse
 	outside                udp.Conn
 	outside                udp.Conn
 	config                 HandshakeConfig
 	config                 HandshakeConfig
-	OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
+	OutboundHandshakeTimer *LockingTimerWheel[netip.Addr]
 	messageMetrics         *MessageMetrics
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
 	metricTimedOut         metrics.Counter
@@ -61,17 +61,17 @@ type HandshakeManager struct {
 	l                      *logrus.Logger
 	l                      *logrus.Logger
 
 
 	// can be used to trigger outbound handshake for the given vpnIp
 	// can be used to trigger outbound handshake for the given vpnIp
-	trigger chan iputil.VpnIp
+	trigger chan netip.Addr
 }
 }
 
 
 type HandshakeHostInfo struct {
 type HandshakeHostInfo struct {
 	sync.Mutex
 	sync.Mutex
 
 
-	startTime   time.Time       // Time that we first started trying with this handshake
-	ready       bool            // Is the handshake ready
-	counter     int             // How many attempts have we made so far
-	lastRemotes []*udp.Addr     // Remotes that we sent to during the previous attempt
-	packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+	startTime   time.Time        // Time that we first started trying with this handshake
+	ready       bool             // Is the handshake ready
+	counter     int              // How many attempts have we made so far
+	lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
+	packetStore []*cachedPacket  // A set of packets to be transmitted once the handshake completes
 
 
 	hostinfo *HostInfo
 	hostinfo *HostInfo
 }
 }
@@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType,
 
 
 func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
 	return &HandshakeManager{
-		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
+		vpnIps:                 map[netip.Addr]*HandshakeHostInfo{},
 		indexes:                map[uint32]*HandshakeHostInfo{},
 		indexes:                map[uint32]*HandshakeHostInfo{},
 		mainHostMap:            mainHostMap,
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		lightHouse:             lightHouse,
 		outside:                outside,
 		outside:                outside,
 		config:                 config,
 		config:                 config,
-		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
-		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		trigger:                make(chan netip.Addr, config.triggerBuffer),
+		OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		messageMetrics:         config.messageMetrics,
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 	}
 	}
 }
 }
 
 
-func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
 	// First remote allow list check before we know the vpnIp
-	if addr != nil {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+	if addr.IsValid() {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 			return
 		}
 		}
@@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	}
 	}
 }
 }
 
 
-func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) {
 	hh := hm.queryVpnIp(vpnIp)
 	hh := hm.queryVpnIp(vpnIp)
 	if hh == nil {
 	if hh == nil {
 		return
 		return
@@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	}
 	}
 
 
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
 	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
-	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
+	remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes)
 
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// This is a very specific optimization for a fast lighthouse reply.
 	// This is a very specific optimization for a fast lighthouse reply.
@@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 	}
 	}
 
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
-	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
+	var sentTo []netip.AddrPort
+	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
 		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
 		if err != nil {
@@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 		// Send a RelayRequest to all known Relay IP's
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 		for _, relay := range hostinfo.remotes.relays {
 			// Don't relay to myself, and don't relay through the host I'm trying to connect to
 			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
+			if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() {
 				continue
 				continue
 			}
 			}
-			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
-			if relayHostInfo == nil || relayHostInfo.remote == nil {
+			relayHostInfo := hm.mainHostMap.QueryVpnIp(relay)
+			if relayHostInfo == nil || !relayHostInfo.remote.IsValid() {
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				hm.f.Handshake(*relay)
+				hm.f.Handshake(relay)
 				continue
 				continue
 			}
 			}
 			// Check the relay HostInfo to see if we already established a relay through it
 			// Check the relay HostInfo to see if we already established a relay through it
@@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 				case Requested:
 				case Requested:
 					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+
+					//TODO: IPV6-WORK
+					myVpnIpB := hm.f.myVpnNet.Addr().As4()
+					theirVpnIpB := vpnIp.As4()
+
 					// Re-send the CreateRelay request, in case the previous one was lost.
 					// Re-send the CreateRelay request, in case the previous one was lost.
 					m := NebulaControl{
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
-						RelayToIp:           uint32(vpnIp),
+						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
+						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					}
 					msg, err := m.Marshal()
 					msg, err := m.Marshal()
 					if err != nil {
 					if err != nil {
@@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 						// This must send over the hostinfo, not over hm.Hosts[ip]
 						// This must send over the hostinfo, not over hm.Hosts[ip]
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.lightHouse.myVpnIp,
+							"relayFrom":           hm.f.myVpnNet.Addr(),
 							"relayTo":             vpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
-							"relay":               *relay}).
+							"relay":               relay}).
 							Info("send CreateRelayRequest")
 							Info("send CreateRelayRequest")
 					}
 					}
 				default:
 				default:
@@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 				}
 				}
 			} else {
 			} else {
 				// No relays exist or requested yet.
 				// No relays exist or requested yet.
-				if relayHostInfo.remote != nil {
+				if relayHostInfo.remote.IsValid() {
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					if err != nil {
 					if err != nil {
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 					}
 
 
+					//TODO: IPV6-WORK
+					myVpnIpB := hm.f.myVpnNet.Addr().As4()
+					theirVpnIpB := vpnIp.As4()
+
 					m := NebulaControl{
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
-						RelayToIp:           uint32(vpnIp),
+						RelayFromIp:         binary.BigEndian.Uint32(myVpnIpB[:]),
+						RelayToIp:           binary.BigEndian.Uint32(theirVpnIpB[:]),
 					}
 					}
 					msg, err := m.Marshal()
 					msg, err := m.Marshal()
 					if err != nil {
 					if err != nil {
@@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 					} else {
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.lightHouse.myVpnIp,
+							"relayFrom":           hm.f.myVpnNet.Addr(),
 							"relayTo":             vpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"initiatorRelayIndex": idx,
-							"relay":               *relay}).
+							"relay":               relay}).
 							Info("send CreateRelayRequest")
 							Info("send CreateRelayRequest")
 					}
 					}
 				}
 				}
@@ -355,7 +364,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 
 
 // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
 // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
 // The 2nd argument will be true if the hostinfo is ready to transmit traffic
 // The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
+func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	hm.mainHostMap.RLock()
 	hm.mainHostMap.RLock()
 	h, ok := hm.mainHostMap.Hosts[vpnIp]
 	h, ok := hm.mainHostMap.Hosts[vpnIp]
 	hm.mainHostMap.RUnlock()
 	hm.mainHostMap.RUnlock()
@@ -372,7 +381,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 }
 }
 
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 	hm.Lock()
 
 
 	if hh, ok := hm.vpnIps[vpnIp]; ok {
 	if hh, ok := hm.vpnIps[vpnIp]; ok {
@@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 		vpnIp:           vpnIp,
 		vpnIp:           vpnIp,
 		HandshakePacket: make(map[uint8][]byte, 0),
 		HandshakePacket: make(map[uint8][]byte, 0),
 		relayState: RelayState{
 		relayState: RelayState{
-			relays:        map[iputil.VpnIp]struct{}{},
-			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relays:        map[netip.Addr]struct{}{},
+			relayForByIp:  map[netip.Addr]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
 		},
 	}
 	}
@@ -555,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
 func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	delete(c.vpnIps, hostinfo.vpnIp)
 	delete(c.vpnIps, hostinfo.vpnIp)
 	if len(c.vpnIps) == 0 {
 	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
+		c.vpnIps = map[netip.Addr]*HandshakeHostInfo{}
 	}
 	}
 
 
 	delete(c.indexes, hostinfo.localIndexId)
 	delete(c.indexes, hostinfo.localIndexId)
@@ -570,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 	}
 }
 }
 
 
-func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
 	hh := hm.queryVpnIp(vpnIp)
 	hh := hm.queryVpnIp(vpnIp)
 	if hh != nil {
 	if hh != nil {
 		return hh.hostinfo
 		return hh.hostinfo
@@ -579,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 
 
 }
 }
 
 
-func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 	return hm.vpnIps[vpnIp]
 	return hm.vpnIps[vpnIp]
@@ -599,7 +608,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
 	return hm.indexes[index]
 	return hm.indexes[index]
 }
 }
 
 
-func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
+func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix {
 	return c.mainHostMap.GetPreferredRanges()
 	return c.mainHostMap.GetPreferredRanges()
 }
 }
 
 

+ 9 - 9
handshake_manager_test.go

@@ -1,13 +1,12 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
@@ -15,10 +14,11 @@ import (
 
 
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	ip := netip.MustParseAddr("172.1.1.2")
+
+	preferredRanges := []netip.Prefix{localrange}
 	mainHM := newHostMap(l, vpncidr)
 	mainHM := newHostMap(l, vpncidr)
 	mainHM.preferredRanges.Store(&preferredRanges)
 	mainHM.preferredRanges.Store(&preferredRanges)
 
 
@@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.NotContains(t, blah.vpnIps, ip)
 	assert.NotContains(t, blah.vpnIps, ip)
 }
 }
 
 
-func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
+func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 	for _, i := range tw.t.wheel {
 	for _, i := range tw.t.wheel {
 		n := i.Head
 		n := i.Head
 		for n != nil {
 		for n != nil {
@@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
 type mockEncWriter struct {
 type mockEncWriter struct {
 }
 }
 
 
-func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
 	return
 	return
 }
 }
 
 
@@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	return
 	return
 }
 }
 
 
-func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}
+func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}

+ 75 - 71
hostmap.go

@@ -3,18 +3,17 @@ package nebula
 import (
 import (
 	"errors"
 	"errors"
 	"net"
 	"net"
+	"net/netip"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // const ProbeLen = 100
 // const ProbeLen = 100
@@ -49,7 +48,7 @@ type Relay struct {
 	State       int
 	State       int
 	LocalIndex  uint32
 	LocalIndex  uint32
 	RemoteIndex uint32
 	RemoteIndex uint32
-	PeerIp      iputil.VpnIp
+	PeerIp      netip.Addr
 }
 }
 
 
 type HostMap struct {
 type HostMap struct {
@@ -57,9 +56,9 @@ type HostMap struct {
 	Indexes         map[uint32]*HostInfo
 	Indexes         map[uint32]*HostInfo
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	RemoteIndexes   map[uint32]*HostInfo
 	RemoteIndexes   map[uint32]*HostInfo
-	Hosts           map[iputil.VpnIp]*HostInfo
-	preferredRanges atomic.Pointer[[]*net.IPNet]
-	vpnCIDR         *net.IPNet
+	Hosts           map[netip.Addr]*HostInfo
+	preferredRanges atomic.Pointer[[]netip.Prefix]
+	vpnCIDR         netip.Prefix
 	l               *logrus.Logger
 	l               *logrus.Logger
 }
 }
 
 
@@ -69,12 +68,12 @@ type HostMap struct {
 type RelayState struct {
 type RelayState struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
-	relays        map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
-	relayForByIp  map[iputil.VpnIp]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
-	relayForByIdx map[uint32]*Relay         // Maps a local index to some Relay info
+	relays        map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer
+	relayForByIp  map[netip.Addr]*Relay   // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info
+	relayForByIdx map[uint32]*Relay       // Maps a local index to some Relay info
 }
 }
 
 
-func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
+func (rs *RelayState) DeleteRelay(ip netip.Addr) {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
 	delete(rs.relays, ip)
 	delete(rs.relays, ip)
@@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay {
 	return ret
 	return ret
 }
 }
 
 
-func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
 	r, ok := rs.relayForByIp[ip]
 	r, ok := rs.relayForByIp[ip]
 	return r, ok
 	return r, ok
 }
 }
 
 
-func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) {
+func (rs *RelayState) InsertRelayTo(ip netip.Addr) {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
 	rs.relays[ip] = struct{}{}
 	rs.relays[ip] = struct{}{}
 }
 }
 
 
-func (rs *RelayState) CopyRelayIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayIps() []netip.Addr {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	ret := make([]iputil.VpnIp, 0, len(rs.relays))
+	ret := make([]netip.Addr, 0, len(rs.relays))
 	for ip := range rs.relays {
 	for ip := range rs.relays {
 		ret = append(ret, ip)
 		ret = append(ret, ip)
 	}
 	}
 	return ret
 	return ret
 }
 }
 
 
-func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp {
+func (rs *RelayState) CopyRelayForIps() []netip.Addr {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
-	currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp))
+	currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp))
 	for relayIp := range rs.relayForByIp {
 	for relayIp := range rs.relayForByIp {
 		currentRelays = append(currentRelays, relayIp)
 		currentRelays = append(currentRelays, relayIp)
 	}
 	}
@@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 {
 	return ret
 	return ret
 }
 }
 
 
-func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) {
-	rs.Lock()
-	defer rs.Unlock()
-	r, ok := rs.relayForByIdx[localIdx]
-	if !ok {
-		return iputil.VpnIp(0), false
-	}
-	delete(rs.relayForByIdx, localIdx)
-	delete(rs.relayForByIp, r.PeerIp)
-	return r.PeerIp, true
-}
-
-func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool {
+func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
 	r, ok := rs.relayForByIp[vpnIp]
 	r, ok := rs.relayForByIp[vpnIp]
@@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re
 	return &newRelay, true
 	return &newRelay, true
 }
 }
 
 
-func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) {
+func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) {
 	rs.RLock()
 	rs.RLock()
 	defer rs.RUnlock()
 	defer rs.RUnlock()
 	r, ok := rs.relayForByIp[vpnIp]
 	r, ok := rs.relayForByIp[vpnIp]
@@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) {
 	return r, ok
 	return r, ok
 }
 }
 
 
-func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
+func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) {
 	rs.Lock()
 	rs.Lock()
 	defer rs.Unlock()
 	defer rs.Unlock()
 	rs.relayForByIp[ip] = r
 	rs.relayForByIp[ip] = r
@@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
 }
 }
 
 
 type HostInfo struct {
 type HostInfo struct {
-	remote          *udp.Addr
+	remote          netip.AddrPort
 	remotes         *RemoteList
 	remotes         *RemoteList
 	promoteCounter  atomic.Uint32
 	promoteCounter  atomic.Uint32
 	ConnectionState *ConnectionState
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	remoteIndexId   uint32
 	localIndexId    uint32
 	localIndexId    uint32
-	vpnIp           iputil.VpnIp
+	vpnIp           netip.Addr
 	recvError       atomic.Uint32
 	recvError       atomic.Uint32
-	remoteCidr      *cidr.Tree4[struct{}]
+	remoteCidr      *bart.Table[struct{}]
 	relayState      RelayState
 	relayState      RelayState
 
 
 	// HandshakePacket records the packets used to create this hostinfo
 	// HandshakePacket records the packets used to create this hostinfo
@@ -227,7 +214,7 @@ type HostInfo struct {
 	lastHandshakeTime uint64
 	lastHandshakeTime uint64
 
 
 	lastRoam       time.Time
 	lastRoam       time.Time
-	lastRoamRemote *udp.Addr
+	lastRoamRemote netip.AddrPort
 
 
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Synchronised via hostmap lock and not the hostinfo lock.
 	// Synchronised via hostmap lock and not the hostinfo lock.
@@ -254,7 +241,7 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 	dropped metrics.Counter
 }
 }
 
 
-func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
+func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap {
 	hm := newHostMap(l, vpnCIDR)
 	hm := newHostMap(l, vpnCIDR)
 
 
 	hm.reload(c, true)
 	hm.reload(c, true)
@@ -269,12 +256,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho
 	return hm
 	return hm
 }
 }
 
 
-func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
+func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap {
 	return &HostMap{
 	return &HostMap{
 		Indexes:       map[uint32]*HostInfo{},
 		Indexes:       map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		Relays:        map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
 		RemoteIndexes: map[uint32]*HostInfo{},
-		Hosts:         map[iputil.VpnIp]*HostInfo{},
+		Hosts:         map[netip.Addr]*HostInfo{},
 		vpnCIDR:       vpnCIDR,
 		vpnCIDR:       vpnCIDR,
 		l:             l,
 		l:             l,
 	}
 	}
@@ -282,11 +269,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
 
 
 func (hm *HostMap) reload(c *config.C, initial bool) {
 func (hm *HostMap) reload(c *config.C, initial bool) {
 	if initial || c.HasChanged("preferred_ranges") {
 	if initial || c.HasChanged("preferred_ranges") {
-		var preferredRanges []*net.IPNet
+		var preferredRanges []netip.Prefix
 		rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
 		rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
 
 
 		for _, rawPreferredRange := range rawPreferredRanges {
 		for _, rawPreferredRange := range rawPreferredRanges {
-			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
+			preferredRange, err := netip.ParsePrefix(rawPreferredRange)
 
 
 			if err != nil {
 			if err != nil {
 				hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
 				hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
@@ -378,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
 		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
 		delete(hm.Hosts, hostinfo.vpnIp)
 		delete(hm.Hosts, hostinfo.vpnIp)
 		if len(hm.Hosts) == 0 {
 		if len(hm.Hosts) == 0 {
-			hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+			hm.Hosts = map[netip.Addr]*HostInfo{}
 		}
 		}
 
 
 		if hostinfo.next != nil {
 		if hostinfo.next != nil {
@@ -461,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo {
 	return hm.queryVpnIp(vpnIp, nil)
 	return hm.queryVpnIp(vpnIp, nil)
 }
 }
 
 
-func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 
 
@@ -483,7 +470,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
 	return nil, nil, errors.New("unable to find host with relay")
 	return nil, nil, errors.New("unable to find host with relay")
 }
 }
 
 
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
+func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
 		hm.RUnlock()
@@ -535,7 +522,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
+func (hm *HostMap) GetPreferredRanges() []netip.Prefix {
 	//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
 	//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
 	return *hm.preferredRanges.Load()
 	return *hm.preferredRanges.Load()
 }
 }
@@ -560,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 
 
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
-func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
+func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
 	c := i.promoteCounter.Add(1)
 	if c%ifce.tryPromoteEvery.Load() == 0 {
 	if c%ifce.tryPromoteEvery.Load() == 0 {
 		remote := i.remote
 		remote := i.remote
 
 
 		// return early if we are already on a preferred remote
 		// return early if we are already on a preferred remote
-		if remote != nil {
-			rIP := remote.IP
+		if remote.IsValid() {
+			rIP := remote.Addr()
 			for _, l := range preferredRanges {
 			for _, l := range preferredRanges {
 				if l.Contains(rIP) {
 				if l.Contains(rIP) {
 					return
 					return
@@ -575,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 			}
 		}
 		}
 
 
-		i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) {
-			if remote != nil && (addr == nil || !preferred) {
+		i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) {
+			if remote.IsValid() && (!addr.IsValid() || !preferred) {
 				return
 				return
 			}
 			}
 
 
@@ -605,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	return nil
 	return nil
 }
 }
 
 
-func (i *HostInfo) SetRemote(remote *udp.Addr) {
+func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	// We copy here because we likely got this remote from a source that reuses the object
-	if !i.remote.Equals(remote) {
-		i.remote = remote.Copy()
-		i.remotes.LearnRemote(i.vpnIp, remote.Copy())
+	if i.remote != remote {
+		i.remote = remote
+		i.remotes.LearnRemote(i.vpnIp, remote)
 	}
 	}
 }
 }
 
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
-	if newRemote == nil {
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
+	if !newRemote.IsValid() {
 		// relays have nil udp Addrs
 		// relays have nil udp Addrs
 		return false
 		return false
 	}
 	}
 	currentRemote := i.remote
 	currentRemote := i.remote
-	if currentRemote == nil {
+	if !currentRemote.IsValid() {
 		i.SetRemote(newRemote)
 		i.SetRemote(newRemote)
 		return true
 		return true
 	}
 	}
@@ -631,11 +618,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	newIsPreferred := false
 	newIsPreferred := false
 	for _, l := range hm.GetPreferredRanges() {
 	for _, l := range hm.GetPreferredRanges() {
 		// return early if we are already on a preferred remote
 		// return early if we are already on a preferred remote
-		if l.Contains(currentRemote.IP) {
+		if l.Contains(currentRemote.Addr()) {
 			return false
 			return false
 		}
 		}
 
 
-		if l.Contains(newRemote.IP) {
+		if l.Contains(newRemote.Addr()) {
 			newIsPreferred = true
 			newIsPreferred = true
 		}
 		}
 	}
 	}
@@ -643,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 	if newIsPreferred {
 	if newIsPreferred {
 		// Consider this a roaming event
 		// Consider this a roaming event
 		i.lastRoam = time.Now()
 		i.lastRoam = time.Now()
-		i.lastRoamRemote = currentRemote.Copy()
+		i.lastRoamRemote = currentRemote
 
 
 		i.SetRemote(newRemote)
 		i.SetRemote(newRemote)
 
 
@@ -666,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 		return
 	}
 	}
 
 
-	remoteCidr := cidr.NewTree4[struct{}]()
+	remoteCidr := new(bart.Table[struct{}])
 	for _, ip := range c.Details.Ips {
 	for _, ip := range c.Details.Ips {
-		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+		//TODO: IPV6-WORK what to do when ip is invalid?
+		nip, _ := netip.AddrFromSlice(ip.IP)
+		nip = nip.Unmap()
+		bits, _ := ip.Mask.Size()
+		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
 	}
 	}
 
 
 	for _, n := range c.Details.Subnets {
 	for _, n := range c.Details.Subnets {
-		remoteCidr.AddCIDR(n, struct{}{})
+		//TODO: IPV6-WORK what to do when ip is invalid?
+		nip, _ := netip.AddrFromSlice(n.IP)
+		nip = nip.Unmap()
+		bits, _ := n.Mask.Size()
+		remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{})
 	}
 	}
 	i.remoteCidr = remoteCidr
 	i.remoteCidr = remoteCidr
 }
 }
@@ -697,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {
 
 
 // Utility functions
 // Utility functions
 
 
-func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
+func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr {
 	//FIXME: This function is pretty garbage
 	//FIXME: This function is pretty garbage
-	var ips []net.IP
+	var ips []netip.Addr
 	ifaces, _ := net.Interfaces()
 	ifaces, _ := net.Interfaces()
 	for _, i := range ifaces {
 	for _, i := range ifaces {
 		allow := allowList.AllowName(i.Name)
 		allow := allowList.AllowName(i.Name)
@@ -721,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP {
 				ip = v.IP
 				ip = v.IP
 			}
 			}
 
 
+			nip, ok := netip.AddrFromSlice(ip)
+			if !ok {
+				if l.Level >= logrus.DebugLevel {
+					l.WithField("localIp", ip).Debug("ip was invalid for netip")
+				}
+				continue
+			}
+			nip = nip.Unmap()
+
 			//TODO: Filtering out link local for now, this is probably the most correct thing
 			//TODO: Filtering out link local for now, this is probably the most correct thing
 			//TODO: Would be nice to filter out SLAAC MAC based ips as well
 			//TODO: Would be nice to filter out SLAAC MAC based ips as well
-			if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() {
-				allow := allowList.Allow(ip)
+			if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false {
+				allow := allowList.Allow(nip)
 				if l.Level >= logrus.TraceLevel {
 				if l.Level >= logrus.TraceLevel {
-					l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow")
+					l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow")
 				}
 				}
 				if !allow {
 				if !allow {
 					continue
 					continue
 				}
 				}
 
 
-				ips = append(ips, ip)
+				ips = append(ips, nip)
 			}
 			}
 		}
 		}
 	}
 	}
-	return &ips
+	return ips
 }
 }

+ 25 - 34
hostmap_test.go

@@ -1,7 +1,7 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -13,18 +13,15 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	hm := newHostMap(
 	hm := newHostMap(
 		l,
 		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
+		netip.MustParsePrefix("10.0.0.1/24"),
 	)
 	)
 
 
 	f := &Interface{}
 	f := &Interface{}
 
 
-	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
-	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
-	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
-	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
 
 
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h4, f)
 	hm.unlockedAddHostInfo(h3, f)
 	hm.unlockedAddHostInfo(h3, f)
@@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 	hm.unlockedAddHostInfo(h1, f)
 
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 	hm.MakePrimary(h3)
 
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 	hm.MakePrimary(h4)
 
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 	hm.MakePrimary(h4)
 
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	hm := newHostMap(
 	hm := newHostMap(
 		l,
 		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
+		netip.MustParsePrefix("10.0.0.1/24"),
 	)
 	)
 
 
 	f := &Interface{}
 	f := &Interface{}
 
 
-	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
-	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
-	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
-	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
-	h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
-	h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+	h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1}
+	h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2}
+	h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3}
+	h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4}
+	h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5}
+	h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6}
 
 
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h6, f)
 	hm.unlockedAddHostInfo(h5, f)
 	hm.unlockedAddHostInfo(h5, f)
@@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h)
 	assert.Nil(t, h)
 
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 	assert.Nil(t, h1.next)
 
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -159,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 	assert.Nil(t, h3.next)
 
 
 	// Make sure we go h2 -> h4 -> h5
 	// Make sure we go h2 -> h4 -> h5
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 	assert.Nil(t, h5.next)
 
 
 	// Make sure we go h2 -> h4
 	// Make sure we go h2 -> h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
@@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 	assert.Nil(t, h2.next)
 
 
 	// Make sure we only have h4
 	// Make sure we only have h4
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
 	assert.Nil(t, prim.next)
@@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 	assert.Nil(t, h4.next)
 
 
 	// Make sure we have nil
 	// Make sure we have nil
-	prim = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1"))
 	assert.Nil(t, prim)
 	assert.Nil(t, prim)
 }
 }
 
 
@@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) {
 
 
 	hm := NewHostMapFromConfig(
 	hm := NewHostMapFromConfig(
 		l,
 		l,
-		&net.IPNet{
-			IP:   net.IP{10, 0, 0, 1},
-			Mask: net.IPMask{255, 255, 255, 0},
-		},
+		netip.MustParsePrefix("10.0.0.1/24"),
 		c,
 		c,
 	)
 	)
 
 
-	toS := func(ipn []*net.IPNet) []string {
+	toS := func(ipn []netip.Prefix) []string {
 		var s []string
 		var s []string
 		for _, n := range ipn {
 		for _, n := range ipn {
 			s = append(s, n.String())
 			s = append(s, n.String())

+ 4 - 2
hostmap_tester.go

@@ -5,9 +5,11 @@ package nebula
 
 
 // This file contains functions used to export information to the e2e testing framework
 // This file contains functions used to export information to the e2e testing framework
 
 
-import "github.com/slackhq/nebula/iputil"
+import (
+	"net/netip"
+)
 
 
-func (i *HostInfo) GetVpnIp() iputil.VpnIp {
+func (i *HostInfo) GetVpnIp() netip.Addr {
 	return i.vpnIp
 	return i.vpnIp
 }
 }
 
 

+ 20 - 24
inside.go

@@ -1,12 +1,13 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"net/netip"
+
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/noiseutil"
 	"github.com/slackhq/nebula/noiseutil"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
 func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 	}
 
 
 	// Ignore local broadcast packets
 	// Ignore local broadcast packets
-	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast {
+	if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr {
 		return
 		return
 	}
 	}
 
 
-	if fwPacket.RemoteIP == f.myVpnIp {
+	if fwPacket.RemoteIP == f.myVpnNet.Addr() {
 		// Immediately forward packets from self to self.
 		// Immediately forward packets from self to self.
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// This should only happen on Darwin-based and FreeBSD hosts, which
 		// routes packets from the Nebula IP to the Nebula IP through the Nebula
 		// routes packets from the Nebula IP to the Nebula IP through the Nebula
@@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 		return
 	}
 	}
 
 
-	// Ignore broadcast packets
-	if f.dropMulticast && isMulticast(fwPacket.RemoteIP) {
+	// Ignore multicast packets
+	if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() {
 		return
 		return
 	}
 	}
 
 
@@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 
 
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
 	if dropReason == nil {
-		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q)
+		f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
 
 
 	} else {
 	} else {
 		f.rejectInside(packet, out, q)
 		f.rejectInside(packet, out, q)
@@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
 		return
 		return
 	}
 	}
 
 
-	f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q)
+	f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
 }
 }
 
 
-func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
+func (f *Interface) Handshake(vpnIp netip.Addr) {
 	f.getOrHandshake(vpnIp, nil)
 	f.getOrHandshake(vpnIp, nil)
 }
 }
 
 
 // getOrHandshake returns nil if the vpnIp is not routable.
 // getOrHandshake returns nil if the vpnIp is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
-	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
+func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
+	if !f.myVpnNet.Contains(vpnIp) {
 		vpnIp = f.inside.RouteFor(vpnIp)
 		vpnIp = f.inside.RouteFor(vpnIp)
-		if vpnIp == 0 {
+		if !vpnIp.IsValid() {
 			return nil, false
 			return nil, false
 		}
 		}
 	}
 	}
@@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 		return
 	}
 	}
 
 
-	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 }
 
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
-func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
+func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) {
 	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
 	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 	})
@@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag
 
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.messageMetrics.Tx(t, st, 1)
-	f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0)
 }
 }
 
 
-func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) {
+func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 }
@@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo,
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 }
 
 
-func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
+func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
 	if ci.eKey == nil {
 		//TODO: log warning
 		//TODO: log warning
 		return
 		return
 	}
 	}
-	useRelay := remote == nil && hostinfo.remote == nil
+	useRelay := !remote.IsValid() && !hostinfo.remote.IsValid()
 	fullOut := out
 	fullOut := out
 
 
 	if useRelay {
 	if useRelay {
@@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		return
 		return
 	}
 	}
 
 
-	if remote != nil {
+	if remote.IsValid() {
 		err = f.writers[q].WriteTo(out, remote)
 		err = f.writers[q].WriteTo(out, remote)
 		if err != nil {
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
 			hostinfo.logger(f.l).WithError(err).
 				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 				WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 		}
 		}
-	} else if hostinfo.remote != nil {
+	} else if hostinfo.remote.IsValid() {
 		err = f.writers[q].WriteTo(out, hostinfo.remote)
 		err = f.writers[q].WriteTo(out, hostinfo.remote)
 		if err != nil {
 		if err != nil {
 			hostinfo.logger(f.l).WithError(err).
 			hostinfo.logger(f.l).WithError(err).
@@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 		}
 		}
 	}
 	}
 }
 }
-
-func isMulticast(ip iputil.VpnIp) bool {
-	// Class D multicast
-	return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
-}

+ 36 - 11
interface.go

@@ -2,10 +2,11 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/binary"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"runtime"
 	"runtime"
 	"sync/atomic"
 	"sync/atomic"
@@ -16,7 +17,6 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/overlay"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 )
 )
@@ -63,8 +63,8 @@ type Interface struct {
 	serveDns           bool
 	serveDns           bool
 	createTime         time.Time
 	createTime         time.Time
 	lightHouse         *LightHouse
 	lightHouse         *LightHouse
-	localBroadcast     iputil.VpnIp
-	myVpnIp            iputil.VpnIp
+	myBroadcastAddr    netip.Addr
+	myVpnNet           netip.Prefix
 	dropLocalBroadcast bool
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	dropMulticast      bool
 	routines           int
 	routines           int
@@ -102,9 +102,9 @@ type EncWriter interface {
 		out []byte,
 		out []byte,
 		nocopy bool,
 		nocopy bool,
 	)
 	)
-	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
 	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
-	Handshake(vpnIp iputil.VpnIp)
+	Handshake(vpnIp netip.Addr)
 }
 }
 
 
 type sendRecvErrorConfig uint8
 type sendRecvErrorConfig uint8
@@ -115,10 +115,10 @@ const (
 	sendRecvErrorPrivate
 	sendRecvErrorPrivate
 )
 )
 
 
-func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool {
+func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool {
 	switch s {
 	switch s {
 	case sendRecvErrorPrivate:
 	case sendRecvErrorPrivate:
-		return ip.IsPrivate()
+		return ip.Addr().IsPrivate()
 	case sendRecvErrorAlways:
 	case sendRecvErrorAlways:
 		return true
 		return true
 	case sendRecvErrorNever:
 	case sendRecvErrorNever:
@@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}
 	}
 
 
 	certificate := c.pki.GetCertState().Certificate
 	certificate := c.pki.GetCertState().Certificate
-	myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
+
+	myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+	if !ok {
+		return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP)
+	}
+
+	myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask)
+	if !ok {
+		return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask)
+	}
+
+	myVpnAddr = myVpnAddr.Unmap()
+	myVpnMask = myVpnMask.Unmap()
+
+	if myVpnAddr.BitLen() != myVpnMask.BitLen() {
+		return nil, fmt.Errorf("ip address and mask are different lengths in certificate")
+	}
+
+	ones, _ := certificate.Details.Ips[0].Mask.Size()
+	myVpnNet := netip.PrefixFrom(myVpnAddr, ones)
+
 	ifce := &Interface{
 	ifce := &Interface{
 		pki:                c.pki,
 		pki:                c.pki,
 		hostMap:            c.HostMap,
 		hostMap:            c.HostMap,
@@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		handshakeManager:   c.HandshakeManager,
 		handshakeManager:   c.HandshakeManager,
 		createTime:         time.Now(),
 		createTime:         time.Now(),
 		lightHouse:         c.lightHouse,
 		lightHouse:         c.lightHouse,
-		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
 		dropMulticast:      c.DropMulticast,
 		routines:           c.routines,
 		routines:           c.routines,
 		version:            c.version,
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
-		myVpnIp:            myVpnIp,
+		myVpnNet:           myVpnNet,
 		relayManager:       c.relayManager,
 		relayManager:       c.relayManager,
 
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
@@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 		l: c.l,
 	}
 	}
 
 
+	if myVpnAddr.Is4() {
+		addr := myVpnNet.Masked().Addr().As4()
+		binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask))
+		ifce.myBroadcastAddr = netip.AddrFrom4(addr)
+	}
+
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.tryPromoteEvery.Store(c.tryPromoteEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 	ifce.reQueryWait.Store(int64(c.reQueryWait))

+ 2 - 0
iputil/packet.go

@@ -6,6 +6,8 @@ import (
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
 
 
+//TODO: IPV6-WORK can probably delete this
+
 const (
 const (
 	// Need 96 bytes for the largest reject packet:
 	// Need 96 bytes for the largest reject packet:
 	// - 20 byte ipv4 header
 	// - 20 byte ipv4 header

+ 0 - 93
iputil/util.go

@@ -1,93 +0,0 @@
-package iputil
-
-import (
-	"encoding/binary"
-	"fmt"
-	"net"
-	"net/netip"
-)
-
-type VpnIp uint32
-
-const maxIPv4StringLen = len("255.255.255.255")
-
-func (ip VpnIp) String() string {
-	b := make([]byte, maxIPv4StringLen)
-
-	n := ubtoa(b, 0, byte(ip>>24))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip>>16&255))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip>>8&255))
-	b[n] = '.'
-	n++
-
-	n += ubtoa(b, n, byte(ip&255))
-	return string(b[:n])
-}
-
-func (ip VpnIp) MarshalJSON() ([]byte, error) {
-	return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil
-}
-
-func (ip VpnIp) ToIP() net.IP {
-	nip := make(net.IP, 4)
-	binary.BigEndian.PutUint32(nip, uint32(ip))
-	return nip
-}
-
-func (ip VpnIp) ToNetIpAddr() netip.Addr {
-	var nip [4]byte
-	binary.BigEndian.PutUint32(nip[:], uint32(ip))
-	return netip.AddrFrom4(nip)
-}
-
-func Ip2VpnIp(ip []byte) VpnIp {
-	if len(ip) == 16 {
-		return VpnIp(binary.BigEndian.Uint32(ip[12:16]))
-	}
-	return VpnIp(binary.BigEndian.Uint32(ip))
-}
-
-func ToNetIpAddr(ip net.IP) (netip.Addr, error) {
-	addr, ok := netip.AddrFromSlice(ip)
-	if !ok {
-		return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip)
-	}
-	return addr, nil
-}
-
-func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) {
-	addr, err := ToNetIpAddr(ipNet.IP)
-	if err != nil {
-		return netip.Prefix{}, err
-	}
-	ones, bits := ipNet.Mask.Size()
-	if ones == 0 && bits == 0 {
-		return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet)
-	}
-	return netip.PrefixFrom(addr, ones), nil
-}
-
-// ubtoa encodes the string form of the integer v to dst[start:] and
-// returns the number of bytes written to dst. The caller must ensure
-// that dst has sufficient length.
-func ubtoa(dst []byte, start int, v byte) int {
-	if v < 10 {
-		dst[start] = v + '0'
-		return 1
-	} else if v < 100 {
-		dst[start+1] = v%10 + '0'
-		dst[start] = v/10 + '0'
-		return 2
-	}
-
-	dst[start+2] = v%10 + '0'
-	dst[start+1] = (v/10)%10 + '0'
-	dst[start] = v/100 + '0'
-	return 3
-}

+ 0 - 17
iputil/util_test.go

@@ -1,17 +0,0 @@
-package iputil
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestVpnIp_String(t *testing.T) {
-	assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String())
-	assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String())
-	assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String())
-	assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String())
-	assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String())
-	assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String())
-}

+ 212 - 188
lighthouse.go

@@ -7,16 +7,16 @@ import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
+	"strconv"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 )
 )
@@ -26,25 +26,18 @@ import (
 
 
 var ErrHostNotKnown = errors.New("host not known")
 var ErrHostNotKnown = errors.New("host not known")
 
 
-type netIpAndPort struct {
-	ip   net.IP
-	port uint16
-}
-
 type LightHouse struct {
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
 	sync.RWMutex //Because we concurrently read and write to our maps
 	ctx          context.Context
 	ctx          context.Context
 	amLighthouse bool
 	amLighthouse bool
-	myVpnIp      iputil.VpnIp
-	myVpnZeros   iputil.VpnIp
-	myVpnNet     *net.IPNet
+	myVpnNet     netip.Prefix
 	punchConn    udp.Conn
 	punchConn    udp.Conn
 	punchy       *Punchy
 	punchy       *Punchy
 
 
 	// Local cache of answers from light houses
 	// Local cache of answers from light houses
 	// map of vpn Ip to answers
 	// map of vpn Ip to answers
-	addrMap map[iputil.VpnIp]*RemoteList
+	addrMap map[netip.Addr]*RemoteList
 
 
 	// filters remote addresses allowed for each host
 	// filters remote addresses allowed for each host
 	// - When we are a lighthouse, this filters what addresses we store and
 	// - When we are a lighthouse, this filters what addresses we store and
@@ -57,26 +50,26 @@ type LightHouse struct {
 	localAllowList atomic.Pointer[LocalAllowList]
 	localAllowList atomic.Pointer[LocalAllowList]
 
 
 	// used to trigger the HandshakeManager when we receive HostQueryReply
 	// used to trigger the HandshakeManager when we receive HostQueryReply
-	handshakeTrigger chan<- iputil.VpnIp
+	handshakeTrigger chan<- netip.Addr
 
 
 	// staticList exists to avoid having a bool in each addrMap entry
 	// staticList exists to avoid having a bool in each addrMap entry
 	// since static should be rare
 	// since static should be rare
-	staticList  atomic.Pointer[map[iputil.VpnIp]struct{}]
-	lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}]
+	staticList  atomic.Pointer[map[netip.Addr]struct{}]
+	lighthouses atomic.Pointer[map[netip.Addr]struct{}]
 
 
 	interval     atomic.Int64
 	interval     atomic.Int64
 	updateCancel context.CancelFunc
 	updateCancel context.CancelFunc
 	ifce         EncWriter
 	ifce         EncWriter
 	nebulaPort   uint32 // 32 bits because protobuf does not have a uint16
 	nebulaPort   uint32 // 32 bits because protobuf does not have a uint16
 
 
-	advertiseAddrs atomic.Pointer[[]netIpAndPort]
+	advertiseAddrs atomic.Pointer[[]netip.AddrPort]
 
 
 	// IP's of relays that can be used by peers to access me
 	// IP's of relays that can be used by peers to access me
-	relaysForMe atomic.Pointer[[]iputil.VpnIp]
+	relaysForMe atomic.Pointer[[]netip.Addr]
 
 
-	queryChan chan iputil.VpnIp
+	queryChan chan netip.Addr
 
 
-	calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
+	calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
 
 
 	metrics           *MessageMetrics
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
 	metricHolepunchTx metrics.Counter
@@ -85,7 +78,7 @@ type LightHouse struct {
 
 
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
 // addrMap should be nil unless this is during a config reload
 // addrMap should be nil unless this is during a config reload
-func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) {
+func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) {
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	nebulaPort := uint32(c.GetInt("listen.port", 0))
 	if amLighthouse && nebulaPort == 0 {
 	if amLighthouse && nebulaPort == 0 {
@@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		if err != nil {
 		if err != nil {
 			return nil, util.NewContextualError("Failed to get listening port", nil, err)
 			return nil, util.NewContextualError("Failed to get listening port", nil, err)
 		}
 		}
-		nebulaPort = uint32(uPort.Port)
+		nebulaPort = uint32(uPort.Port())
 	}
 	}
 
 
-	ones, _ := myVpnNet.Mask.Size()
 	h := LightHouse{
 	h := LightHouse{
 		ctx:          ctx,
 		ctx:          ctx,
 		amLighthouse: amLighthouse,
 		amLighthouse: amLighthouse,
-		myVpnIp:      iputil.Ip2VpnIp(myVpnNet.IP),
-		myVpnZeros:   iputil.VpnIp(32 - ones),
 		myVpnNet:     myVpnNet,
 		myVpnNet:     myVpnNet,
-		addrMap:      make(map[iputil.VpnIp]*RemoteList),
+		addrMap:      make(map[netip.Addr]*RemoteList),
 		nebulaPort:   nebulaPort,
 		nebulaPort:   nebulaPort,
 		punchConn:    pc,
 		punchConn:    pc,
 		punchy:       p,
 		punchy:       p,
-		queryChan:    make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
+		queryChan:    make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
 		l:            l,
 		l:            l,
 	}
 	}
-	lighthouses := make(map[iputil.VpnIp]struct{})
+	lighthouses := make(map[netip.Addr]struct{})
 	h.lighthouses.Store(&lighthouses)
 	h.lighthouses.Store(&lighthouses)
-	staticList := make(map[iputil.VpnIp]struct{})
+	staticList := make(map[netip.Addr]struct{})
 	h.staticList.Store(&staticList)
 	h.staticList.Store(&staticList)
 
 
 	if c.GetBool("stats.lighthouse_metrics", false) {
 	if c.GetBool("stats.lighthouse_metrics", false) {
@@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 	return &h, nil
 	return &h, nil
 }
 }
 
 
-func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} {
 	return *lh.staticList.Load()
 	return *lh.staticList.Load()
 }
 }
 
 
-func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} {
+func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} {
 	return *lh.lighthouses.Load()
 	return *lh.lighthouses.Load()
 }
 }
 
 
@@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList {
 	return lh.localAllowList.Load()
 	return lh.localAllowList.Load()
 }
 }
 
 
-func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort {
+func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort {
 	return *lh.advertiseAddrs.Load()
 	return *lh.advertiseAddrs.Load()
 }
 }
 
 
-func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
+func (lh *LightHouse) GetRelaysForMe() []netip.Addr {
 	return *lh.relaysForMe.Load()
 	return *lh.relaysForMe.Load()
 }
 }
 
 
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
+func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] {
 	return lh.calculatedRemotes.Load()
 	return lh.calculatedRemotes.Load()
 }
 }
 
 
@@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 {
 func (lh *LightHouse) reload(c *config.C, initial bool) error {
 func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	if initial || c.HasChanged("lighthouse.advertise_addrs") {
 	if initial || c.HasChanged("lighthouse.advertise_addrs") {
 		rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
 		rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{})
-		advAddrs := make([]netIpAndPort, 0)
+		advAddrs := make([]netip.AddrPort, 0)
 
 
 		for i, rawAddr := range rawAdvAddrs {
 		for i, rawAddr := range rawAdvAddrs {
-			fIp, fPort, err := udp.ParseIPAndPort(rawAddr)
+			host, sport, err := net.SplitHostPort(rawAddr)
 			if err != nil {
 			if err != nil {
 				return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
 				return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
 			}
 			}
 
 
-			if fPort == 0 {
-				fPort = uint16(lh.nebulaPort)
+			ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
+			if err != nil {
+				return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+			}
+			if len(ips) == 0 {
+				return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil)
+			}
+
+			port, err := strconv.Atoi(sport)
+			if err != nil {
+				return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err)
+			}
+
+			if port == 0 {
+				port = int(lh.nebulaPort)
 			}
 			}
 
 
-			if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) {
+			//TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used
+			ip := ips[0].Unmap()
+			if lh.myVpnNet.Contains(ip) {
 				lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
 				lh.l.WithField("addr", rawAddr).WithField("entry", i+1).
 					Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
 					Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range")
 				continue
 				continue
 			}
 			}
 
 
-			advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort})
+			advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port)))
 		}
 		}
 
 
 		lh.advertiseAddrs.Store(&advAddrs)
 		lh.advertiseAddrs.Store(&advAddrs)
@@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 			lh.RUnlock()
 			lh.RUnlock()
 		}
 		}
 		// Build a new list based on current config.
 		// Build a new list based on current config.
-		staticList := make(map[iputil.VpnIp]struct{})
-		err := lh.loadStaticMap(c, lh.myVpnNet, staticList)
+		staticList := make(map[netip.Addr]struct{})
+		err := lh.loadStaticMap(c, staticList)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	}
 	}
 
 
 	if initial || c.HasChanged("lighthouse.hosts") {
 	if initial || c.HasChanged("lighthouse.hosts") {
-		lhMap := make(map[iputil.VpnIp]struct{})
-		err := lh.parseLighthouses(c, lh.myVpnNet, lhMap)
+		lhMap := make(map[netip.Addr]struct{})
+		err := lh.parseLighthouses(c, lhMap)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 			if len(c.GetStringSlice("relay.relays", nil)) > 0 {
 			if len(c.GetStringSlice("relay.relays", nil)) > 0 {
 				lh.l.Info("Ignoring relays from config because am_relay is true")
 				lh.l.Info("Ignoring relays from config because am_relay is true")
 			}
 			}
-			relaysForMe := []iputil.VpnIp{}
+			relaysForMe := []netip.Addr{}
 			lh.relaysForMe.Store(&relaysForMe)
 			lh.relaysForMe.Store(&relaysForMe)
 		case false:
 		case false:
-			relaysForMe := []iputil.VpnIp{}
+			relaysForMe := []netip.Addr{}
 			for _, v := range c.GetStringSlice("relay.relays", nil) {
 			for _, v := range c.GetStringSlice("relay.relays", nil) {
 				lh.l.WithField("relay", v).Info("Read relay from config")
 				lh.l.WithField("relay", v).Info("Read relay from config")
 
 
-				configRIP := net.ParseIP(v)
-				if configRIP != nil {
-					relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP))
+				configRIP, err := netip.ParseAddr(v)
+				//TODO: We could print the error here
+				if err == nil {
+					relaysForMe = append(relaysForMe, configRIP)
 				}
 				}
 			}
 			}
 			lh.relaysForMe.Store(&relaysForMe)
 			lh.relaysForMe.Store(&relaysForMe)
@@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error {
 	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
 	lhs := c.GetStringSlice("lighthouse.hosts", []string{})
 	if lh.amLighthouse && len(lhs) != 0 {
 	if lh.amLighthouse && len(lhs) != 0 {
 		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 		lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config")
 	}
 	}
 
 
 	for i, host := range lhs {
 	for i, host := range lhs {
-		ip := net.ParseIP(host)
-		if ip == nil {
-			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil)
+		ip, err := netip.ParseAddr(host)
+		if err != nil {
+			return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err)
 		}
 		}
-		if !tunCidr.Contains(ip) {
-			return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil)
+		if !lh.myVpnNet.Contains(ip) {
+			return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil)
 		}
 		}
-		lhMap[iputil.Ip2VpnIp(ip)] = struct{}{}
+		lhMap[ip] = struct{}{}
 	}
 	}
 
 
 	if !lh.amLighthouse && len(lhMap) == 0 {
 	if !lh.amLighthouse && len(lhMap) == 0 {
@@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) {
 	return network, nil
 	return network, nil
 }
 }
 
 
-func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error {
 	d, err := getStaticMapCadence(c)
 	d, err := getStaticMapCadence(c)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 		return err
 		return err
 	}
 	}
 
 
-	lookup_timeout, err := getStaticMapLookupTimeout(c)
+	lookupTimeout, err := getStaticMapLookupTimeout(c)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	i := 0
 	i := 0
 
 
 	for k, v := range shm {
 	for k, v := range shm {
-		rip := net.ParseIP(fmt.Sprintf("%v", k))
-		if rip == nil {
-			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil)
+		vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k))
+		if err != nil {
+			return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err)
 		}
 		}
 
 
-		if !tunCidr.Contains(rip) {
-			return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil)
+		if !lh.myVpnNet.Contains(vpnIp) {
+			return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil)
 		}
 		}
 
 
-		vpnIp := iputil.Ip2VpnIp(rip)
 		vals, ok := v.([]interface{})
 		vals, ok := v.([]interface{})
 		if !ok {
 		if !ok {
 			vals = []interface{}{v}
 			vals = []interface{}{v}
@@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
 			remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v))
 		}
 		}
 
 
-		err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList)
+		err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	return nil
 	return nil
 }
 }
 
 
-func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) Query(ip netip.Addr) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip)
 		lh.QueryServer(ip)
 	}
 	}
@@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
 }
 }
 
 
 // QueryServer is asynchronous so no reply should be expected
 // QueryServer is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
+func (lh *LightHouse) QueryServer(ip netip.Addr) {
 	// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
 	// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
 	if lh.amLighthouse || lh.IsLighthouseIP(ip) {
 	if lh.amLighthouse || lh.IsLighthouseIP(ip) {
 		return
 		return
@@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
 	lh.queryChan <- ip
 	lh.queryChan <- ip
 }
 }
 
 
-func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList {
 	lh.RLock()
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
 		lh.RUnlock()
@@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
 // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
 // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
 // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp
 // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
 // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo()
-func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) {
+func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) {
 	lh.RLock()
 	lh.RLock()
 	// Do we have an entry in the main cache?
 	// Do we have an entry in the main cache?
 	if v, ok := lh.addrMap[vpnIp]; ok {
 	if v, ok := lh.addrMap[vpnIp]; ok {
@@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in
 	return false, 0, nil
 	return false, 0, nil
 }
 }
 
 
-func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
+func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) {
 	// First we check the static mapping
 	// First we check the static mapping
 	// and do nothing if it is there
 	// and do nothing if it is there
 	if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
 	if _, ok := lh.GetStaticHostList()[vpnIp]; ok {
@@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) {
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
 // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it
-func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error {
+func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error {
 	lh.Lock()
 	lh.Lock()
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am := lh.unlockedGetRemoteList(vpnIp)
 	am.Lock()
 	am.Lock()
@@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
 	am.unlockedSetHostnamesResults(hr)
 	am.unlockedSetHostnamesResults(hr)
 
 
 	for _, addrPort := range hr.GetIPs() {
 	for _, addrPort := range hr.GetIPs() {
-
+		if !lh.shouldAdd(vpnIp, addrPort.Addr()) {
+			continue
+		}
 		switch {
 		switch {
 		case addrPort.Addr().Is4():
 		case addrPort.Addr().Is4():
-			to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
-			if !lh.unlockedShouldAddV4(vpnIp, to) {
-				continue
-			}
-			am.unlockedPrependV4(lh.myVpnIp, to)
+			am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
 		case addrPort.Addr().Is6():
 		case addrPort.Addr().Is6():
-			to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())
-			if !lh.unlockedShouldAddV6(vpnIp, to) {
-				continue
-			}
-			am.unlockedPrependV6(lh.myVpnIp, to)
+			am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()))
 		}
 		}
 	}
 	}
 
 
@@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t
 // addCalculatedRemotes adds any calculated remotes based on the
 // addCalculatedRemotes adds any calculated remotes based on the
 // lighthouse.calculated_remotes configuration. It returns true if any
 // lighthouse.calculated_remotes configuration. It returns true if any
 // calculated remotes were added
 // calculated remotes were added
-func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool {
 	tree := lh.getCalculatedRemotes()
 	tree := lh.getCalculatedRemotes()
 	if tree == nil {
 	if tree == nil {
 		return false
 		return false
 	}
 	}
-	ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+	calculatedRemotes, ok := tree.Lookup(vpnIp)
 	if !ok {
 	if !ok {
 		return false
 		return false
 	}
 	}
@@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 	defer am.Unlock()
 	defer am.Unlock()
 	lh.Unlock()
 	lh.Unlock()
 
 
-	am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4)
+	am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4)
 
 
 	return len(calculated) > 0
 	return len(calculated) > 0
 }
 }
 
 
 // unlockedGetRemoteList assumes you have the lh lock
 // unlockedGetRemoteList assumes you have the lh lock
-func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
+func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList {
 	am, ok := lh.addrMap[vpnIp]
 	am, ok := lh.addrMap[vpnIp]
 	if !ok {
 	if !ok {
 		am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
 		am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) })
@@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList {
 	return am
 	return am
 }
 }
 
 
-func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool {
-	switch {
-	case to.Is4():
-		ipBytes := to.As4()
-		ip := iputil.Ip2VpnIp(ipBytes[:])
-		allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip)
-		if lh.l.Level >= logrus.TraceLevel {
-			lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
-		}
-		if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) {
-			return false
-		}
-	case to.Is6():
-		ipBytes := to.As16()
-
-		hi := binary.BigEndian.Uint64(ipBytes[:8])
-		lo := binary.BigEndian.Uint64(ipBytes[8:])
-		allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo)
-		if lh.l.Level >= logrus.TraceLevel {
-			lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow")
-		}
-
-		// We don't check our vpn network here because nebula does not support ipv6 on the inside
-		if !allow {
-			return false
-		}
+func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool {
+	allow := lh.GetRemoteAllowList().Allow(vpnIp, to)
+	if lh.l.Level >= logrus.TraceLevel {
+		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
+	}
+	if !allow || lh.myVpnNet.Contains(to) {
+		return false
 	}
 	}
+
 	return true
 	return true
 }
 }
 
 
 // unlockedShouldAddV4 checks if to is allowed by our allow list
 // unlockedShouldAddV4 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool {
-	allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip))
+func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool {
+	ip := AddrPortFromIp4AndPort(to)
+	allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
 	if lh.l.Level >= logrus.TraceLevel {
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
 		lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 	}
 
 
-	if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) {
+	if !allow || lh.myVpnNet.Contains(ip.Addr()) {
 		return false
 		return false
 	}
 	}
 
 
@@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo
 }
 }
 
 
 // unlockedShouldAddV6 checks if to is allowed by our allow list
 // unlockedShouldAddV6 checks if to is allowed by our allow list
-func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool {
-	allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo)
+func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool {
+	ip := AddrPortFromIp6AndPort(to)
+	allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr())
 	if lh.l.Level >= logrus.TraceLevel {
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 	}
 
 
-	// We don't check our vpn network here because nebula does not support ipv6 on the inside
-	if !allow {
+	if !allow || lh.myVpnNet.Contains(ip.Addr()) {
 		return false
 		return false
 	}
 	}
 
 
@@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 	return ip
 	return ip
 }
 }
 
 
-func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool {
+func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool {
 	if _, ok := lh.GetLighthouses()[vpnIp]; ok {
 	if _, ok := lh.GetLighthouses()[vpnIp]; ok {
 		return true
 		return true
 	}
 	}
 	return false
 	return false
 }
 }
 
 
-func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta {
+func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta {
+	if vpnIp.Is6() {
+		//TODO: need to support ipv6
+		panic("ipv6 is not yet supported")
+	}
+
+	b := vpnIp.As4()
 	return &NebulaMeta{
 	return &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp: uint32(VpnIp),
+			VpnIp: binary.BigEndian.Uint32(b[:]),
 		},
 		},
 	}
 	}
 }
 }
 
 
-func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort {
-	ipp := Ip4AndPort{Port: port}
-	ipp.Ip = uint32(iputil.Ip2VpnIp(ip))
-	return &ipp
+func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort {
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], ip.Ip)
+	return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port))
+}
+
+func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort {
+	b := [16]byte{}
+	binary.BigEndian.PutUint64(b[:8], ip.Hi)
+	binary.BigEndian.PutUint64(b[8:], ip.Lo)
+	return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port))
 }
 }
 
 
 func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
 func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
@@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort {
 	}
 	}
 }
 }
 
 
-func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort {
-	return &Ip6AndPort{
-		Hi:   binary.BigEndian.Uint64(ip[:8]),
-		Lo:   binary.BigEndian.Uint64(ip[8:]),
-		Port: port,
-	}
-}
-
+// TODO: IPV6-WORK we can delete some more of these
 func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
 func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
 	ip6Addr := ip.As16()
 	ip6Addr := ip.As16()
 	return &Ip6AndPort{
 	return &Ip6AndPort{
@@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort {
 		Port: uint32(port),
 		Port: uint32(port),
 	}
 	}
 }
 }
-func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr {
-	ip := ipp.Ip
-	return udp.NewAddr(
-		net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)),
-		uint16(ipp.Port),
-	)
-}
-
-func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
-	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
-}
 
 
 func (lh *LightHouse) startQueryWorker() {
 func (lh *LightHouse) startQueryWorker() {
 	if lh.amLighthouse {
 	if lh.amLighthouse {
@@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() {
 	}()
 	}()
 }
 }
 
 
-func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
+func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) {
 	if lh.IsLighthouseIP(ip) {
 	if lh.IsLighthouseIP(ip) {
 		return
 		return
 	}
 	}
@@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() {
 	var v6 []*Ip6AndPort
 	var v6 []*Ip6AndPort
 
 
 	for _, e := range lh.GetAdvertiseAddrs() {
 	for _, e := range lh.GetAdvertiseAddrs() {
-		if ip := e.ip.To4(); ip != nil {
-			v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port)))
+		if e.Addr().Is4() {
+			v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port()))
 		} else {
 		} else {
-			v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port)))
+			v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port()))
 		}
 		}
 	}
 	}
 
 
 	lal := lh.GetLocalAllowList()
 	lal := lh.GetLocalAllowList()
-	for _, e := range *localIps(lh.l, lal) {
-		if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) {
+	for _, e := range localIps(lh.l, lal) {
+		if lh.myVpnNet.Contains(e) {
 			continue
 			continue
 		}
 		}
 
 
 		// Only add IPs that aren't my VPN/tun IP
 		// Only add IPs that aren't my VPN/tun IP
-		if ip := e.To4(); ip != nil {
-			v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort))
+		if e.Is4() {
+			v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort)))
 		} else {
 		} else {
-			v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort))
+			v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort)))
 		}
 		}
 	}
 	}
 
 
 	var relays []uint32
 	var relays []uint32
 	for _, r := range lh.GetRelaysForMe() {
 	for _, r := range lh.GetRelaysForMe() {
-		relays = append(relays, (uint32)(r))
+		//TODO: IPV6-WORK both relays and vpnip need ipv6 support
+		b := r.As4()
+		relays = append(relays, binary.BigEndian.Uint32(b[:]))
 	}
 	}
 
 
+	//TODO: IPV6-WORK both relays and vpnip need ipv6 support
+	b := lh.myVpnNet.Addr().As4()
+
 	m := &NebulaMeta{
 	m := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp:       uint32(lh.myVpnIp),
+			VpnIp:       binary.BigEndian.Uint32(b[:]),
 			Ip4AndPorts: v4,
 			Ip4AndPorts: v4,
 			Ip6AndPorts: v6,
 			Ip6AndPorts: v6,
 			RelayVpnIp:  relays,
 			RelayVpnIp:  relays,
@@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 }
 }
 
 
 func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
 func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc {
-	return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) {
+	return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) {
 		lhh.HandleRequest(rAddr, vpnIp, p, f)
 		lhh.HandleRequest(rAddr, vpnIp, p, f)
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) {
+func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) {
 	n := lhh.resetMeta()
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
 	err := n.Unmarshal(p)
 	if err != nil {
 	if err != nil {
@@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp,
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) {
 	// Exit if we don't answer queries
 	// Exit if we don't answer queries
 	if !lhh.lh.amLighthouse {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
@@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
 
 
 	//TODO: we can DRY this further
 	//TODO: we can DRY this further
 	reqVpnIp := n.Details.VpnIp
 	reqVpnIp := n.Details.VpnIp
+
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+	queryVpnIp := netip.AddrFrom4(b)
+
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
-	found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) {
+	found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		n.Type = NebulaMeta_HostQueryReply
 		n.Details.VpnIp = reqVpnIp
 		n.Details.VpnIp = reqVpnIp
@@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
 	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
 		n.Type = NebulaMeta_HostPunchNotification
-		n.Details.VpnIp = uint32(vpnIp)
-
+		//TODO: IPV6-WORK
+		b = vpnIp.As4()
+		n.Details.VpnIp = binary.BigEndian.Uint32(b[:])
 		lhh.coalesceAnswers(c, n)
 		lhh.coalesceAnswers(c, n)
 
 
 		return n.MarshalTo(lhh.pb)
 		return n.MarshalTo(lhh.pb)
@@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp,
 	}
 	}
 
 
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
-	w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0])
+
+	//TODO: IPV6-WORK
+	binary.BigEndian.PutUint32(b[:], reqVpnIp)
+	sendTo := netip.AddrFrom4(b)
+	w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 }
 
 
 func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
 func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
@@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
 	}
 	}
 
 
 	if c.relay != nil {
 	if c.relay != nil {
-		n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...)
+		//TODO: IPV6-WORK
+		relays := make([]uint32, len(c.relay.relay))
+		b := [4]byte{}
+		for i, _ := range relays {
+			b = c.relay.relay[i].As4()
+			relays[i] = binary.BigEndian.Uint32(b[:])
+		}
+		n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...)
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) {
+func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 		return
 	}
 	}
 
 
 	lhh.lh.Lock()
 	lhh.lh.Lock()
-	am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp))
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+	certVpnIp := netip.AddrFrom4(b)
+	am := lhh.lh.unlockedGetRemoteList(certVpnIp)
 	am.Lock()
 	am.Lock()
 	lhh.lh.Unlock()
 	lhh.lh.Unlock()
 
 
-	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
+	//TODO: IPV6-WORK
 	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
 	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
 	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
 	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
-	am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+
+	//TODO: IPV6-WORK
+	relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+	for i, _ := range n.Details.RelayVpnIp {
+		binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+		relays[i] = netip.AddrFrom4(b)
+	}
+	am.unlockedSetRelay(vpnIp, certVpnIp, relays)
 	am.Unlock()
 	am.Unlock()
 
 
 	// Non-blocking attempt to trigger, skip if it would block
 	// Non-blocking attempt to trigger, skip if it would block
 	select {
 	select {
-	case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp):
+	case lhh.lh.handshakeTrigger <- certVpnIp:
 	default:
 	default:
 	}
 	}
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
 	if !lhh.lh.amLighthouse {
 	if !lhh.lh.amLighthouse {
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
 			lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp)
@@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	}
 	}
 
 
 	//Simple check that the host sent this not someone else
 	//Simple check that the host sent this not someone else
-	if n.Details.VpnIp != uint32(vpnIp) {
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+	detailsVpnIp := netip.AddrFrom4(b)
+	if detailsVpnIp != vpnIp {
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update")
+			lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
 		}
 		}
 		return
 		return
 	}
 	}
@@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	am.Lock()
 	am.Lock()
 	lhh.lh.Unlock()
 	lhh.lh.Unlock()
 
 
-	certVpnIp := iputil.VpnIp(n.Details.VpnIp)
-	am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
-	am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp)
+	am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+
+	//TODO: IPV6-WORK
+	relays := make([]netip.Addr, len(n.Details.RelayVpnIp))
+	for i, _ := range n.Details.RelayVpnIp {
+		binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i])
+		relays[i] = netip.AddrFrom4(b)
+	}
+	am.unlockedSetRelay(vpnIp, detailsVpnIp, relays)
 	am.Unlock()
 	am.Unlock()
 
 
 	n = lhh.resetMeta()
 	n = lhh.resetMeta()
 	n.Type = NebulaMeta_HostUpdateNotificationAck
 	n.Type = NebulaMeta_HostUpdateNotificationAck
-	n.Details.VpnIp = uint32(vpnIp)
+
+	//TODO: IPV6-WORK
+	vpnIpB := vpnIp.As4()
+	n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:])
 	ln, err := n.MarshalTo(lhh.pb)
 	ln, err := n.MarshalTo(lhh.pb)
 
 
 	if err != nil {
 	if err != nil {
@@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 	w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 }
 
 
-func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) {
+func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 	if !lhh.lh.IsLighthouseIP(vpnIp) {
 		return
 		return
 	}
 	}
 
 
 	empty := []byte{0}
 	empty := []byte{0}
-	punch := func(vpnPeer *udp.Addr) {
-		if vpnPeer == nil {
+	punch := func(vpnPeer netip.AddrPort) {
+		if !vpnPeer.IsValid() {
 			return
 			return
 		}
 		}
 
 
@@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
 
 
 		if lhh.l.Level >= logrus.DebugLevel {
 		if lhh.l.Level >= logrus.DebugLevel {
 			//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
 			//TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp))
-			lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp))
+			//TODO: IPV6-WORK, make this debug line not suck
+			b := [4]byte{}
+			binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+			lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b))
 		}
 		}
 	}
 	}
 
 
 	for _, a := range n.Details.Ip4AndPorts {
 	for _, a := range n.Details.Ip4AndPorts {
-		punch(NewUDPAddrFromLH4(a))
+		punch(AddrPortFromIp4AndPort(a))
 	}
 	}
 
 
 	for _, a := range n.Details.Ip6AndPorts {
 	for _, a := range n.Details.Ip6AndPorts {
-		punch(NewUDPAddrFromLH6(a))
+		punch(AddrPortFromIp6AndPort(a))
 	}
 	}
 
 
 	// This sends a nebula test packet to the host trying to contact us. In the case
 	// This sends a nebula test packet to the host trying to contact us. In the case
 	// of a double nat or other difficult scenario, this may help establish
 	// of a double nat or other difficult scenario, this may help establish
 	// a tunnel.
 	// a tunnel.
 	if lhh.lh.punchy.GetRespond() {
 	if lhh.lh.punchy.GetRespond() {
-		queryVpnIp := iputil.VpnIp(n.Details.VpnIp)
+		//TODO: IPV6-WORK
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], n.Details.VpnIp)
+		queryVpnIp := netip.AddrFrom4(b)
 		go func() {
 		go func() {
 			time.Sleep(lhh.lh.punchy.GetRespondDelay())
 			time.Sleep(lhh.lh.punchy.GetRespondDelay())
 			if lhh.l.Level >= logrus.DebugLevel {
 			if lhh.l.Level >= logrus.DebugLevel {
@@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i
 		}()
 		}()
 	}
 	}
 }
 }
-
-// ipMaskContains checks if testIp is contained by ip after applying a cidr.
-// zeros is 32 - bits from net.IPMask.Size()
-func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool {
-	return (testIp^ip)>>zeros == 0
-}

+ 85 - 102
lighthouse_test.go

@@ -2,15 +2,14 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/binary"
 	"fmt"
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
-	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"gopkg.in/yaml.v2"
 	"gopkg.in/yaml.v2"
 )
 )
@@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) {
 	var m Ip4AndPort
 	var m Ip4AndPort
 	err := m.Unmarshal(b)
 	err := m.Unmarshal(b)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
-	assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String())
+	ip := netip.MustParseAddr("10.1.1.1")
+	bp := ip.As4()
+	assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp())
 }
 }
 
 
 func TestNewLhQuery(t *testing.T) {
 func TestNewLhQuery(t *testing.T) {
-	myIp := net.ParseIP("192.1.1.1")
-	myIpint := iputil.Ip2VpnIp(myIp)
+	myIp, err := netip.ParseAddr("192.1.1.1")
+	assert.NoError(t, err)
 
 
 	// Generating a new lh query should work
 	// Generating a new lh query should work
-	a := NewLhQueryByInt(myIpint)
+	a := NewLhQueryByInt(myIp)
 
 
 	// The result should be a nebulameta protobuf
 	// The result should be a nebulameta protobuf
 	assert.IsType(t, &NebulaMeta{}, a)
 	assert.IsType(t, &NebulaMeta{}, a)
@@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) {
 
 
 func Test_lhStaticMapping(t *testing.T) {
 func Test_lhStaticMapping(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
@@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) {
 
 
 func TestReloadLighthouseInterval(t *testing.T) {
 func TestReloadLighthouseInterval(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/16")
 	lh1 := "10.128.0.2"
 	lh1 := "10.128.0.2"
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
@@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) {
 	lh.ifce = &mockEncWriter{}
 	lh.ifce = &mockEncWriter{}
 
 
 	// The first one routine is kicked off by main.go currently, lets make sure that one dies
 	// The first one routine is kicked off by main.go currently, lets make sure that one dies
-	c.ReloadConfigString("lighthouse:\n  interval: 5")
+	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 5"))
 	assert.Equal(t, int64(5), lh.interval.Load())
 	assert.Equal(t, int64(5), lh.interval.Load())
 
 
 	// Subsequent calls are killed off by the LightHouse.Reload function
 	// Subsequent calls are killed off by the LightHouse.Reload function
-	c.ReloadConfigString("lighthouse:\n  interval: 10")
+	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 10"))
 	assert.Equal(t, int64(10), lh.interval.Load())
 	assert.Equal(t, int64(10), lh.interval.Load())
 
 
 	// If this completes then nothing is stealing our reload routine
 	// If this completes then nothing is stealing our reload routine
-	c.ReloadConfigString("lighthouse:\n  interval: 11")
+	assert.NoError(t, c.ReloadConfigString("lighthouse:\n  interval: 11"))
 	assert.Equal(t, int64(11), lh.interval.Load())
 	assert.Equal(t, int64(11), lh.interval.Load())
 }
 }
 
 
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func BenchmarkLighthouseHandleRequest(b *testing.B) {
 	l := test.NewLogger()
 	l := test.NewLogger()
-	_, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0")
+	myVpnNet := netip.MustParsePrefix("10.128.0.1/0")
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
 	lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil)
@@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		b.Fatal()
 		b.Fatal()
 	}
 	}
 
 
-	hAddr := udp.NewAddrFromString("4.5.6.7:12345")
-	hAddr2 := udp.NewAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = NewRemoteList(nil)
-	lh.addrMap[3].unlockedSetV4(
-		3,
-		3,
+	hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
+	hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
+
+	vpnIp3 := netip.MustParseAddr("0.0.0.3")
+	lh.addrMap[vpnIp3] = NewRemoteList(nil)
+	lh.addrMap[vpnIp3].unlockedSetV4(
+		vpnIp3,
+		vpnIp3,
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
-			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+			NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()),
+			NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()),
 		},
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
-	rAddr := udp.NewAddrFromString("1.2.2.3:12345")
-	rAddr2 := udp.NewAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = NewRemoteList(nil)
-	lh.addrMap[2].unlockedSetV4(
-		3,
-		3,
+	rAddr := netip.MustParseAddrPort("1.2.2.3:12345")
+	rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346")
+	vpnIp2 := netip.MustParseAddr("0.0.0.3")
+	lh.addrMap[vpnIp2] = NewRemoteList(nil)
+	lh.addrMap[vpnIp2].unlockedSetV4(
+		vpnIp3,
+		vpnIp3,
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
-			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+			NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()),
+			NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()),
 		},
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
@@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		p, err := req.Marshal()
 		p, err := req.Marshal()
 		assert.NoError(b, err)
 		assert.NoError(b, err)
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, 2, p, mw)
+			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
 		}
 		}
 	})
 	})
 	b.Run("found", func(b *testing.B) {
 	b.Run("found", func(b *testing.B) {
@@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 		assert.NoError(b, err)
 		assert.NoError(b, err)
 
 
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
-			lhh.HandleRequest(rAddr, 2, p, mw)
+			lhh.HandleRequest(rAddr, vpnIp2, p, mw)
 		}
 		}
 	})
 	})
 }
 }
@@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 func TestLighthouse_Memory(t *testing.T) {
 func TestLighthouse_Memory(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 
 
-	myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242}
-	myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242}
-	myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242}
-	myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242}
-	myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242}
-	myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243}
-	myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244}
-	myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245}
-	myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246}
-	myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247}
-	myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248}
-	myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249}
-	myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2"))
-
-	theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242}
-	theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242}
-	theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242}
-	theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242}
-	theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242}
-	theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3"))
+	myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242")
+	myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242")
+	myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242")
+	myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242")
+	myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242")
+	myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243")
+	myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244")
+	myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245")
+	myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246")
+	myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247")
+	myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248")
+	myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249")
+	myVpnIp := netip.MustParseAddr("10.128.0.2")
+
+	theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242")
+	theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242")
+	theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242")
+	theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242")
+	theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242")
+	theirVpnIp := netip.MustParseAddr("10.128.0.3")
 
 
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 	lhh := lh.NewRequestHandler()
 	lhh := lh.NewRequestHandler()
 
 
 	// Test that my first update responds with just that
 	// Test that my first update responds with just that
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2)
 
 
 	// Ensure we don't accumulate addresses
 	// Ensure we don't accumulate addresses
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3)
 
 
 	// Grow it back to 2
 	// Grow it back to 2
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh)
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
 
 	// Update a different host and ask about it
 	// Update a different host and ask about it
-	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
+	newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
 	r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4)
 
 
@@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	newLHHostUpdate(
 	newLHHostUpdate(
 		myUdpAddr0,
 		myUdpAddr0,
 		myVpnIp,
 		myVpnIp,
-		[]*udp.Addr{
+		[]netip.AddrPort{
 			myUdpAddr1,
 			myUdpAddr1,
 			myUdpAddr2,
 			myUdpAddr2,
 			myUdpAddr3,
 			myUdpAddr3,
@@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) {
 	)
 	)
 
 
 	// Make sure we won't add ips in our vpn network
 	// Make sure we won't add ips in our vpn network
-	bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242}
-	bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242}
-	good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242}
-	newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh)
+	bad1 := netip.MustParseAddrPort("10.128.0.99:4242")
+	bad2 := netip.MustParseAddrPort("10.128.0.100:4242")
+	good := netip.MustParseAddrPort("1.128.0.99:4242")
+	newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good)
 }
 }
@@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) {
 	c := config.NewC(l)
 	c := config.NewC(l)
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
 	c.Settings["listen"] = map[interface{}]interface{}{"port": 4242}
-	lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil)
+	lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
 	nc := map[interface{}]interface{}{
 	nc := map[interface{}]interface{}{
@@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) {
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 }
 }
 
 
-func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply {
+func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {
+	//TODO: IPV6-WORK
+	bip := queryVpnIp.As4()
 	req := &NebulaMeta{
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostQuery,
 		Type: NebulaMeta_HostQuery,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp: uint32(queryVpnIp),
+			VpnIp: binary.BigEndian.Uint32(bip[:]),
 		},
 		},
 	}
 	}
 
 
@@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh
 	return w.lastReply
 	return w.lastReply
 }
 }
 
 
-func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) {
+func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) {
+	//TODO: IPV6-WORK
+	bip := vpnIp.As4()
 	req := &NebulaMeta{
 	req := &NebulaMeta{
 		Type: NebulaMeta_HostUpdateNotification,
 		Type: NebulaMeta_HostUpdateNotification,
 		Details: &NebulaMetaDetails{
 		Details: &NebulaMetaDetails{
-			VpnIp:       uint32(vpnIp),
+			VpnIp:       binary.BigEndian.Uint32(bip[:]),
 			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
 			Ip4AndPorts: make([]*Ip4AndPort, len(addrs)),
 		},
 		},
 	}
 	}
 
 
 	for k, v := range addrs {
 	for k, v := range addrs {
-		req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)}
+		req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port())
 	}
 	}
 
 
 	b, err := req.Marshal()
 	b, err := req.Marshal()
@@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr,
 //	)
 //	)
 //}
 //}
 
 
-func Test_ipMaskContains(t *testing.T) {
-	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255"))))
-	assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-	assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1"))))
-}
-
 type testLhReply struct {
 type testLhReply struct {
 	nebType    header.MessageType
 	nebType    header.MessageType
 	nebSubType header.MessageSubType
 	nebSubType header.MessageSubType
-	vpnIp      iputil.VpnIp
+	vpnIp      netip.Addr
 	msg        *NebulaMeta
 	msg        *NebulaMeta
 }
 }
 
 
@@ -414,7 +416,7 @@ type testEncWriter struct {
 
 
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
 }
 }
-func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
+func (tw *testEncWriter) Handshake(vpnIp netip.Addr) {
 }
 }
 
 
 func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
 func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
@@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M
 	}
 	}
 }
 }
 
 
-func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
+func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)
 	err := msg.Unmarshal(p)
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
 	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
@@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess
 }
 }
 
 
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
 // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
-func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) {
-	if !assert.Len(t, have, len(want)) {
-		return
-	}
-
-	for k, w := range want {
-		if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) {
-			assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have)))
-		}
-	}
-}
-
-// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
-func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) {
+func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) {
 	if !assert.Len(t, have, len(want)) {
 	if !assert.Len(t, have, len(want)) {
 		return
 		return
 	}
 	}
 
 
 	for k, w := range want {
 	for k, w := range want {
-		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
-			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+		//TODO: IPV6-WORK
+		h := AddrPortFromIp4AndPort(have[k])
+		if !(h == w) {
+			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h))
 		}
 		}
 	}
 	}
 }
 }
-
-func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr {
-	addrs := make([]*udp.Addr, len(ips))
-	for k, v := range ips {
-		addrs[k] = NewUDPAddrFromLH4(v)
-	}
-	return addrs
-}

+ 22 - 8
main.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"time"
 	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
@@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	}
 	}
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 	l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started")
 
 
-	// TODO: make sure mask is 4 bytes
-	tunCidr := certificate.Details.Ips[0]
+	ones, _ := certificate.Details.Ips[0].Mask.Size()
+	addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP)
+	if !ok {
+		err = util.NewContextualError(
+			"Invalid ip address in certificate",
+			m{"vpnIp": certificate.Details.Ips[0].IP},
+			nil,
+		)
+		return nil, err
+	}
+	tunCidr := netip.PrefixFrom(addr, ones)
 
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	if err != nil {
 	if err != nil {
@@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 
 	if !configTest {
 	if !configTest {
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
 		rawListenHost := c.GetString("listen.host", "0.0.0.0")
-		var listenHost *net.IPAddr
+		var listenHost netip.Addr
 		if rawListenHost == "[::]" {
 		if rawListenHost == "[::]" {
 			// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
 			// Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve.
-			listenHost = &net.IPAddr{IP: net.IPv6zero}
+			listenHost = netip.IPv6Unspecified()
 
 
 		} else {
 		} else {
-			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
+			ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost)
 			if err != nil {
 			if err != nil {
 				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
 				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
 			}
 			}
+			if len(ips) == 0 {
+				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
+			}
+			listenHost = ips[0].Unmap()
 		}
 		}
 
 
 		for i := 0; i < routines; i++ {
 		for i := 0; i < routines; i++ {
-			l.Infof("listening %q %d", listenHost.IP, port)
-			udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64))
+			l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port)))
+			udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64))
 			if err != nil {
 			if err != nil {
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 				return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err)
 			}
 			}
@@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 				if err != nil {
 				if err != nil {
 					return nil, util.NewContextualError("Failed to get listening port", nil, err)
 					return nil, util.NewContextualError("Failed to get listening port", nil, err)
 				}
 				}
-				port = int(uPort.Port)
+				port = int(uPort.Port())
 			}
 			}
 		}
 		}
 	}
 	}

+ 48 - 47
outside.go

@@ -4,6 +4,7 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"net/netip"
 	"time"
 	"time"
 
 
 	"github.com/flynn/noise"
 	"github.com/flynn/noise"
@@ -11,7 +12,6 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/proto"
@@ -21,9 +21,10 @@ const (
 	minFwPacketLen = 4
 	minFwPacketLen = 4
 )
 )
 
 
+// TODO: IPV6-WORK this can likely be removed now
 func readOutsidePackets(f *Interface) udp.EncReader {
 func readOutsidePackets(f *Interface) udp.EncReader {
 	return func(
 	return func(
-		addr *udp.Addr,
+		addr netip.AddrPort,
 		out []byte,
 		out []byte,
 		packet []byte,
 		packet []byte,
 		header *header.H,
 		header *header.H,
@@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader {
 	}
 	}
 }
 }
 
 
-func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	err := h.Parse(packet)
 	if err != nil {
 	if err != nil {
 		// TODO: best if we return this and let caller log
 		// TODO: best if we return this and let caller log
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
 		if len(packet) > 1 {
-			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
+			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
 		}
 		}
 		return
 		return
 	}
 	}
 
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
 	//l.Error("in packet ", header, packet[HeaderLen:])
-	if addr != nil {
-		if ip4 := addr.IP.To4(); ip4 != nil {
-			if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
-				if f.l.Level >= logrus.DebugLevel {
-					f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
-				}
-				return
+	if ip.IsValid() {
+		if f.myVpnNet.Contains(ip.Addr()) {
+			if f.l.Level >= logrus.DebugLevel {
+				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
 			}
 			}
+			return
 		}
 		}
 	}
 	}
 
 
@@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 	switch h.Type {
 	switch h.Type {
 	case header.Message:
 	case header.Message:
 		// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
 		// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 			return
 		}
 		}
 
 
@@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			// Successfully validated the thing. Get rid of the Relay header.
 			// Successfully validated the thing. Get rid of the Relay header.
 			signedPayload = signedPayload[header.Len:]
 			signedPayload = signedPayload[header.Len:]
 			// Pull the Roaming parts up here, and return in all call paths.
 			// Pull the Roaming parts up here, and return in all call paths.
-			f.handleHostRoaming(hostinfo, addr)
+			f.handleHostRoaming(hostinfo, ip)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			f.connectionManager.In(hostinfo.localIndexId)
 			f.connectionManager.In(hostinfo.localIndexId)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			case TerminalType:
 			case TerminalType:
 				// If I am the target of this relay, process the unwrapped packet
 				// If I am the target of this relay, process the unwrapped packet
 				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
 				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+				f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
 				return
 				return
 			case ForwardingType:
 			case ForwardingType:
 				// Find the target HostInfo relay object
 				// Find the target HostInfo relay object
@@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 
 	case header.LightHouse:
 	case header.LightHouse:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 			return
 		}
 		}
 
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
 				Error("Failed to decrypt lighthouse packet")
 
 
@@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			return
 			return
 		}
 		}
 
 
-		lhf(addr, hostinfo.vpnIp, d)
+		lhf(ip, hostinfo.vpnIp, d)
 
 
 		// Fallthrough to the bottom to record incoming traffic
 		// Fallthrough to the bottom to record incoming traffic
 
 
 	case header.Test:
 	case header.Test:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 			return
 		}
 		}
 
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
 				Error("Failed to decrypt test packet")
 
 
@@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 		if h.Subtype == header.TestRequest {
 		if h.Subtype == header.TestRequest {
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
 			// to the new IP address before responding
-			f.handleHostRoaming(hostinfo, addr)
+			f.handleHostRoaming(hostinfo, ip)
 			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
 			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
 		}
 		}
 
 
@@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 
 	case header.Handshake:
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handshakeManager.HandleIncoming(addr, via, packet, h)
+		f.handshakeManager.HandleIncoming(ip, via, packet, h)
 		return
 		return
 
 
 	case header.RecvError:
 	case header.RecvError:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handleRecvError(addr, h)
+		f.handleRecvError(ip, h)
 		return
 		return
 
 
 	case header.CloseTunnel:
 	case header.CloseTunnel:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 			return
 		}
 		}
 
 
-		hostinfo.logger(f.l).WithField("udpAddr", addr).
+		hostinfo.logger(f.l).WithField("udpAddr", ip).
 			Info("Close tunnel received, tearing down.")
 			Info("Close tunnel received, tearing down.")
 
 
 		f.closeTunnel(hostinfo)
 		f.closeTunnel(hostinfo)
 		return
 		return
 
 
 	case header.Control:
 	case header.Control:
-		if !f.handleEncrypted(ci, addr, h) {
+		if !f.handleEncrypted(ci, ip, h) {
 			return
 			return
 		}
 		}
 
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
+			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
 				WithField("packet", packet).
 				WithField("packet", packet).
 				Error("Failed to decrypt Control packet")
 				Error("Failed to decrypt Control packet")
 			return
 			return
@@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 
 	default:
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
+		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
 		return
 		return
 	}
 	}
 
 
-	f.handleHostRoaming(hostinfo, addr)
+	f.handleHostRoaming(hostinfo, ip)
 
 
 	f.connectionManager.In(hostinfo.localIndexId)
 	f.connectionManager.In(hostinfo.localIndexId)
 }
 }
@@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 }
 
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
-	if addr != nil && !hostinfo.remote.Equals(addr) {
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
-			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) {
+	if ip.IsValid() && hostinfo.remote != ip {
+		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 			return
 		}
 		}
-		if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+		if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			}
 			return
 			return
 		}
 		}
 
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip).
 			Info("Host roamed to new udp ip/port.")
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(addr)
+		hostinfo.SetRemote(ip)
 	}
 	}
 
 
 }
 }
 
 
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
 	// If connectionstate exists and the replay protector allows, process packet
 	// If connectionstate exists and the replay protector allows, process packet
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
 	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
 	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
 	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
-		if addr != nil {
+		if addr.IsValid() {
 			f.maybeSendRecvError(addr, h.RemoteIndex)
 			f.maybeSendRecvError(addr, h.RemoteIndex)
 			return false
 			return false
 		} else {
 		} else {
@@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 
 
 	// Firewall packets are locally oriented
 	// Firewall packets are locally oriented
 	if incoming {
 	if incoming {
-		fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
-		fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
+		//TODO: IPV6-WORK
+		fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
+		fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
@@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
 		}
 		}
 	} else {
 	} else {
-		fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
-		fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
+		//TODO: IPV6-WORK
+		fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
+		fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
 			fp.RemotePort = 0
 			fp.RemotePort = 0
 			fp.LocalPort = 0
 			fp.LocalPort = 0
@@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	return true
 	return true
 }
 }
 
 
-func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {
-	if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) {
+func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) {
+	if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) {
 		f.sendRecvError(endpoint, index)
 		f.sendRecvError(endpoint, index)
 	}
 	}
 }
 }
 
 
-func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
+func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 	f.messageMetrics.Tx(header.RecvError, 0, 1)
 
 
 	//TODO: this should be a signed message so we can trust that we should drop the index
 	//TODO: this should be a signed message so we can trust that we should drop the index
@@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
 	}
 	}
 }
 }
 
 
-func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
+func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) {
 	if f.l.Level >= logrus.DebugLevel {
 	if f.l.Level >= logrus.DebugLevel {
 		f.l.WithField("index", h.RemoteIndex).
 		f.l.WithField("index", h.RemoteIndex).
 			WithField("udpAddr", addr).
 			WithField("udpAddr", addr).
@@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 		return
 		return
 	}
 	}
 
 
-	if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
+	if hostinfo.remote.IsValid() && hostinfo.remote != addr {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return
 		return
 	}
 	}

+ 5 - 5
outside_test.go

@@ -2,10 +2,10 @@ package nebula
 
 
 import (
 import (
 	"net"
 	"net"
+	"net/netip"
 	"testing"
 	"testing"
 
 
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv4"
 )
 )
@@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) {
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
 	assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
-	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
-	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
+	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2"))
+	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1"))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.RemotePort, uint16(3))
 	assert.Equal(t, p.LocalPort, uint16(4))
 	assert.Equal(t, p.LocalPort, uint16(4))
 
 
@@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) {
 
 
 	assert.Nil(t, err)
 	assert.Nil(t, err)
 	assert.Equal(t, p.Protocol, uint8(2))
 	assert.Equal(t, p.Protocol, uint8(2))
-	assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1)))
-	assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2)))
+	assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1"))
+	assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2"))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.RemotePort, uint16(6))
 	assert.Equal(t, p.LocalPort, uint16(5))
 	assert.Equal(t, p.LocalPort, uint16(5))
 }
 }

+ 3 - 5
overlay/device.go

@@ -2,16 +2,14 @@ package overlay
 
 
 import (
 import (
 	"io"
 	"io"
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 )
 
 
 type Device interface {
 type Device interface {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	Activate() error
 	Activate() error
-	Cidr() *net.IPNet
+	Cidr() netip.Prefix
 	Name() string
 	Name() string
-	RouteFor(iputil.VpnIp) iputil.VpnIp
+	RouteFor(netip.Addr) netip.Addr
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 	NewMultiQueueReader() (io.ReadWriteCloser, error)
 }
 }

+ 19 - 25
overlay/route.go

@@ -1,34 +1,30 @@
 package overlay
 package overlay
 
 
 import (
 import (
-	"bytes"
 	"fmt"
 	"fmt"
 	"math"
 	"math"
 	"net"
 	"net"
+	"net/netip"
 	"runtime"
 	"runtime"
 	"strconv"
 	"strconv"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type Route struct {
 type Route struct {
 	MTU     int
 	MTU     int
 	Metric  int
 	Metric  int
-	Cidr    *net.IPNet
-	Via     *iputil.VpnIp
+	Cidr    netip.Prefix
+	Via     netip.Addr
 	Install bool
 	Install bool
 }
 }
 
 
 // Equal determines if a route that could be installed in the system route table is equal to another
 // Equal determines if a route that could be installed in the system route table is equal to another
 // Via is ignored since that is only consumed within nebula itself
 // Via is ignored since that is only consumed within nebula itself
 func (r Route) Equal(t Route) bool {
 func (r Route) Equal(t Route) bool {
-	if !r.Cidr.IP.Equal(t.Cidr.IP) {
-		return false
-	}
-	if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) {
+	if r.Cidr != t.Cidr {
 		return false
 		return false
 	}
 	}
 	if r.Metric != t.Metric {
 	if r.Metric != t.Metric {
@@ -51,21 +47,21 @@ func (r Route) String() string {
 	return s
 	return s
 }
 }
 
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
-	routeTree := cidr.NewTree4[iputil.VpnIp]()
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
+	routeTree := new(bart.Table[netip.Addr])
 	for _, r := range routes {
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
 		}
 		}
 
 
-		if r.Via != nil {
-			routeTree.AddCIDR(r.Cidr, *r.Via)
+		if r.Via.IsValid() {
+			routeTree.Insert(r.Cidr, r.Via)
 		}
 		}
 	}
 	}
 	return routeTree, nil
 	return routeTree, nil
 }
 }
 
 
-func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.routes")
 	r := c.Get("tun.routes")
@@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			MTU:     mtu,
 			MTU:     mtu,
 		}
 		}
 
 
-		_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+		r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 			return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err)
 		}
 		}
 
 
-		if !ipWithin(network, r.Cidr) {
+		if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() {
 			return nil, fmt.Errorf(
 			return nil, fmt.Errorf(
 				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
 				"entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v",
 				i+1,
 				i+1,
@@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 	return routes, nil
 	return routes, nil
 }
 }
 
 
-func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
+func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) {
 	var err error
 	var err error
 
 
 	r := c.Get("tun.unsafe_routes")
 	r := c.Get("tun.unsafe_routes")
@@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
 			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
 		}
 		}
 
 
-		nVia := net.ParseIP(via)
-		if nVia == nil {
-			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via)
+		viaVpnIp, err := netip.ParseAddr(via)
+		if err != nil {
+			return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
 		}
 		}
 
 
 		rRoute, ok := m["route"]
 		rRoute, ok := m["route"]
@@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1)
 		}
 		}
 
 
-		viaVpnIp := iputil.Ip2VpnIp(nVia)
-
 		install := true
 		install := true
 		rInstall, ok := m["install"]
 		rInstall, ok := m["install"]
 		if ok {
 		if ok {
@@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) {
 		}
 		}
 
 
 		r := Route{
 		r := Route{
-			Via:     &viaVpnIp,
+			Via:     viaVpnIp,
 			MTU:     mtu,
 			MTU:     mtu,
 			Metric:  metric,
 			Metric:  metric,
 			Install: install,
 			Install: install,
 		}
 		}
 
 
-		_, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute))
+		r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute))
 		if err != nil {
 		if err != nil {
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 			return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err)
 		}
 		}
 
 
-		if ipWithin(network, r.Cidr) {
+		if network.Contains(r.Cidr.Addr()) {
 			return nil, fmt.Errorf(
 			return nil, fmt.Errorf(
 				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
 				"entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v",
 				i+1,
 				i+1,

+ 27 - 16
overlay/route_test.go

@@ -2,11 +2,10 @@ package overlay
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 
 
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/test"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
@@ -14,7 +13,8 @@ import (
 func Test_parseRoutes(t *testing.T) {
 func Test_parseRoutes(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
-	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	assert.NoError(t, err)
 
 
 	// test no routes config
 	// test no routes config
 	routes, err := parseRoutes(c, n)
 	routes, err := parseRoutes(c, n)
@@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
 	routes, err = parseRoutes(c, n)
 	routes, err = parseRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope")
+	assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 
 	// below network range
 	// below network range
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
@@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) {
 func Test_parseUnsafeRoutes(t *testing.T) {
 func Test_parseUnsafeRoutes(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
-	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	assert.NoError(t, err)
 
 
 	// test no routes config
 	// test no routes config
 	routes, err := parseUnsafeRoutes(c, n)
 	routes, err := parseUnsafeRoutes(c, n)
@@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, n)
 	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope")
+	assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
 
 
 	// missing route
 	// missing route
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
@@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
 	routes, err = parseUnsafeRoutes(c, n)
 	routes, err = parseUnsafeRoutes(c, n)
 	assert.Nil(t, routes)
 	assert.Nil(t, routes)
-	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope")
+	assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
 
 
 	// within network range
 	// within network range
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
@@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) {
 func Test_makeRouteTree(t *testing.T) {
 func Test_makeRouteTree(t *testing.T) {
 	l := test.NewLogger()
 	l := test.NewLogger()
 	c := config.NewC(l)
 	c := config.NewC(l)
-	_, n, _ := net.ParseCIDR("10.0.0.0/24")
+	n, err := netip.ParsePrefix("10.0.0.0/24")
+	assert.NoError(t, err)
 
 
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
 	c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
 		map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
@@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) {
 	routeTree, err := makeRouteTree(l, routes, true)
 	routeTree, err := makeRouteTree(l, routes, true)
 	assert.NoError(t, err)
 	assert.NoError(t, err)
 
 
-	ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
-	ok, r := routeTree.MostSpecificContains(ip)
+	ip, err := netip.ParseAddr("1.0.0.2")
+	assert.NoError(t, err)
+	r, ok := routeTree.Lookup(ip)
 	assert.True(t, ok)
 	assert.True(t, ok)
-	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
 
 
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
-	ok, r = routeTree.MostSpecificContains(ip)
+	nip, err := netip.ParseAddr("192.168.0.1")
+	assert.NoError(t, err)
+	assert.Equal(t, nip, r)
+
+	ip, err = netip.ParseAddr("1.0.0.1")
+	assert.NoError(t, err)
+	r, ok = routeTree.Lookup(ip)
 	assert.True(t, ok)
 	assert.True(t, ok)
-	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
 
 
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
-	ok, r = routeTree.MostSpecificContains(ip)
+	nip, err = netip.ParseAddr("192.168.0.2")
+	assert.NoError(t, err)
+	assert.Equal(t, nip, r)
+
+	ip, err = netip.ParseAddr("1.1.0.1")
+	assert.NoError(t, err)
+	r, ok = routeTree.Lookup(ip)
 	assert.False(t, ok)
 	assert.False(t, ok)
 }
 }

+ 5 - 5
overlay/tun.go

@@ -1,7 +1,7 @@
 package overlay
 package overlay
 
 
 import (
 import (
-	"net"
+	"net/netip"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -11,9 +11,9 @@ import (
 const DefaultMTU = 1300
 const DefaultMTU = 1300
 
 
 // TODO: We may be able to remove routines
 // TODO: We may be able to remove routines
-type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error)
+type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error)
 
 
-func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
 	switch {
 	switch {
 	case c.GetBool("tun.disabled", false):
 	case c.GetBool("tun.disabled", false):
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
 		tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
@@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout
 }
 }
 
 
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
 func NewFdDeviceFromConfig(fd *int) DeviceFactory {
-	return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+	return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
 		return newTunFromFd(c, l, *fd, tunCidr)
 		return newTunFromFd(c, l, *fd, tunCidr)
 	}
 	}
 }
 }
 
 
-func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) {
+func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 	if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") {
 		return false, nil, nil
 		return false, nil, nil
 	}
 	}

+ 9 - 10
overlay/tun_android.go

@@ -6,27 +6,26 @@ package overlay
 import (
 import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"sync/atomic"
 	"sync/atomic"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	fd        int
 	fd        int
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	// Be sure not to call file.Fd() as it will set the fd to blocking mode.
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in Android")
 	return nil, fmt.Errorf("newTun not supported in Android")
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
@@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 41 - 18
overlay/tun_darwin.go

@@ -8,15 +8,15 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
+	"net/netip"
 	"os"
 	"os"
 	"sync/atomic"
 	"sync/atomic"
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 	netroute "golang.org/x/net/route"
 	netroute "golang.org/x/net/route"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
@@ -25,10 +25,10 @@ import (
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	Device     string
 	Device     string
-	cidr       *net.IPNet
+	cidr       netip.Prefix
 	DefaultMTU int
 	DefaultMTU int
 	Routes     atomic.Pointer[[]Route]
 	Routes     atomic.Pointer[[]Route]
-	routeTree  atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree  atomic.Pointer[bart.Table[netip.Addr]]
 	linkAddr   *netroute.LinkAddr
 	linkAddr   *netroute.LinkAddr
 	l          *logrus.Logger
 	l          *logrus.Logger
 
 
@@ -73,7 +73,7 @@ type ifreqMTU struct {
 	pad  [8]byte
 	pad  [8]byte
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	name := c.GetString("tun.dev", "")
 	name := c.GetString("tun.dev", "")
 	ifIndex := -1
 	ifIndex := -1
 	if name != "" && name != "utun" {
 	if name != "" && name != "utun" {
@@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
 	return
 	return
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 	return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
 }
 }
 
 
@@ -188,8 +188,13 @@ func (t *tun) Activate() error {
 
 
 	var addr, mask [4]byte
 	var addr, mask [4]byte
 
 
-	copy(addr[:], t.cidr.IP.To4())
-	copy(mask[:], t.cidr.Mask)
+	if !t.cidr.Addr().Is4() {
+		//TODO: IPV6-WORK
+		panic("need ipv6")
+	}
+
+	addr = t.cidr.Addr().As4()
+	copy(mask[:], prefixToMask(t.cidr))
 
 
 	s, err := unix.Socket(
 	s, err := unix.Socket(
 		unix.AF_INET,
 		unix.AF_INET,
@@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	ok, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, ok := t.routeTree.Load().Lookup(ip)
 	if ok {
 	if ok {
 		return r
 		return r
 	}
 	}
-
-	return 0
+	return netip.Addr{}
 }
 }
 
 
 // Get the LinkAddr for the interface of the given name
 // Get the LinkAddr for the interface of the given name
@@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error {
 	maskAddr := &netroute.Inet4Addr{}
 	maskAddr := &netroute.Inet4Addr{}
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
 	for _, r := range routes {
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
 
 
-		copy(routeAddr.IP[:], r.Cidr.IP.To4())
-		copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+		if !r.Cidr.Addr().Is4() {
+			//TODO: implement ipv6
+			panic("Cant handle ipv6 routes yet")
+		}
+
+		routeAddr.IP = r.Cidr.Addr().As4()
+		//TODO: we could avoid the copy
+		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
 
 
 		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
 		err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
 		if err != nil {
 		if err != nil {
@@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		copy(routeAddr.IP[:], r.Cidr.IP.To4())
-		copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4())
+		if r.Cidr.Addr().Is6() {
+			//TODO: implement ipv6
+			panic("Cant handle ipv6 routes yet")
+		}
+
+		routeAddr.IP = r.Cidr.Addr().As4()
+		copy(maskAddr.IP[:], prefixToMask(r.Cidr))
 
 
 		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
 		err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr)
 		if err != nil {
 		if err != nil {
@@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) {
 	return n - 4, err
 	return n - 4, err
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 
@@ -547,3 +562,11 @@ func (t *tun) Name() string {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 	return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
 }
 }
+
+func prefixToMask(prefix netip.Prefix) []byte {
+	pLen := 128
+	if prefix.Addr().Is4() {
+		pLen = 32
+	}
+	return net.CIDRMask(prefix.Bits(), pLen)
+}

+ 6 - 6
overlay/tun_disabled.go

@@ -3,7 +3,7 @@ package overlay
 import (
 import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"strings"
 	"strings"
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
@@ -13,7 +13,7 @@ import (
 
 
 type disabledTun struct {
 type disabledTun struct {
 	read chan []byte
 	read chan []byte
-	cidr *net.IPNet
+	cidr netip.Prefix
 
 
 	// Track these metrics since we don't have the tun device to do it for us
 	// Track these metrics since we don't have the tun device to do it for us
 	tx metrics.Counter
 	tx metrics.Counter
@@ -21,7 +21,7 @@ type disabledTun struct {
 	l  *logrus.Logger
 	l  *logrus.Logger
 }
 }
 
 
-func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
+func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun {
 	tun := &disabledTun{
 	tun := &disabledTun{
 		cidr: cidr,
 		cidr: cidr,
 		read: make(chan []byte, queueLen),
 		read: make(chan []byte, queueLen),
@@ -43,11 +43,11 @@ func (*disabledTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
-func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
-	return 0
+func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
+	return netip.Addr{}
 }
 }
 
 
-func (t *disabledTun) Cidr() *net.IPNet {
+func (t *disabledTun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 11 - 12
overlay/tun_freebsd.go

@@ -9,7 +9,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"io/fs"
 	"io/fs"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"os/exec"
 	"os/exec"
 	"strconv"
 	"strconv"
@@ -17,10 +17,9 @@ import (
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 )
 )
 
 
@@ -48,10 +47,10 @@ type ifreqDestroy struct {
 
 
 type tun struct {
 type tun struct {
 	Device    string
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
@@ -79,11 +78,11 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	// Try to open existing tun device
 	// Try to open existing tun device
 	var file *os.File
 	var file *os.File
 	var err error
 	var err error
@@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
 	var err error
 	var err error
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 
@@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
 	for _, r := range routes {
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}

+ 9 - 10
overlay/tun_ios.go

@@ -7,32 +7,31 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"syscall"
 	"syscall"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 }
 }
 
 
-func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) {
+func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) {
 	return nil, fmt.Errorf("newTun not supported in iOS")
 	return nil, fmt.Errorf("newTun not supported in iOS")
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/tun")
 	t := &tun{
 	t := &tun{
 		cidr:            cidr,
 		cidr:            cidr,
@@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
@@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error {
 	return tr.f.Close()
 	return tr.f.Close()
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 56 - 35
overlay/tun_linux.go

@@ -4,19 +4,18 @@
 package overlay
 package overlay
 
 
 import (
 import (
-	"bytes"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
+	"net/netip"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"sync/atomic"
 	"sync/atomic"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 	"github.com/vishvananda/netlink"
 	"github.com/vishvananda/netlink"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
@@ -26,7 +25,7 @@ type tun struct {
 	io.ReadWriteCloser
 	io.ReadWriteCloser
 	fd          int
 	fd          int
 	Device      string
 	Device      string
-	cidr        *net.IPNet
+	cidr        netip.Prefix
 	MaxMTU      int
 	MaxMTU      int
 	DefaultMTU  int
 	DefaultMTU  int
 	TXQueueLen  int
 	TXQueueLen  int
@@ -34,7 +33,7 @@ type tun struct {
 	ioctlFd     uintptr
 	ioctlFd     uintptr
 
 
 	Routes          atomic.Pointer[[]Route]
 	Routes          atomic.Pointer[[]Route]
-	routeTree       atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree       atomic.Pointer[bart.Table[netip.Addr]]
 	routeChan       chan struct{}
 	routeChan       chan struct{}
 	useSystemRoutes bool
 	useSystemRoutes bool
 
 
@@ -65,7 +64,7 @@ type ifreqQLEN struct {
 	pad   [8]byte
 	pad   [8]byte
 }
 }
 
 
-func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) {
+func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) {
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 	file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
 
 
 	t, err := newTunGeneric(c, l, file, cidr)
 	t, err := newTunGeneric(c, l, file, cidr)
@@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet)
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) {
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
 	if err != nil {
 	if err != nil {
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
 		// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t
 	return t, nil
 	return t, nil
 }
 }
 
 
-func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) {
+func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) {
 	t := &tun{
 	t := &tun{
 		ReadWriteCloser: file,
 		ReadWriteCloser: file,
 		fd:              int(file.Fd()),
 		fd:              int(file.Fd()),
@@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return file, nil
 	return file, nil
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
@@ -275,8 +274,10 @@ func (t *tun) Activate() error {
 
 
 	var addr, mask [4]byte
 	var addr, mask [4]byte
 
 
-	copy(addr[:], t.cidr.IP.To4())
-	copy(mask[:], t.cidr.Mask)
+	//TODO: IPV6-WORK
+	addr = t.cidr.Addr().As4()
+	tmask := net.CIDRMask(t.cidr.Bits(), 32)
+	copy(mask[:], tmask)
 
 
 	s, err := unix.Socket(
 	s, err := unix.Socket(
 		unix.AF_INET,
 		unix.AF_INET,
@@ -364,14 +365,19 @@ func (t *tun) setMTU() {
 
 
 func (t *tun) setDefaultRoute() error {
 func (t *tun) setDefaultRoute() error {
 	// Default route
 	// Default route
-	dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask}
+
+	dr := &net.IPNet{
+		IP:   t.cidr.Masked().Addr().AsSlice(),
+		Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()),
+	}
+
 	nr := netlink.Route{
 	nr := netlink.Route{
 		LinkIndex: t.deviceIndex,
 		LinkIndex: t.deviceIndex,
 		Dst:       dr,
 		Dst:       dr,
 		MTU:       t.DefaultMTU,
 		MTU:       t.DefaultMTU,
 		AdvMSS:    t.advMSS(Route{}),
 		AdvMSS:    t.advMSS(Route{}),
 		Scope:     unix.RT_SCOPE_LINK,
 		Scope:     unix.RT_SCOPE_LINK,
-		Src:       t.cidr.IP,
+		Src:       net.IP(t.cidr.Addr().AsSlice()),
 		Protocol:  unix.RTPROT_KERNEL,
 		Protocol:  unix.RTPROT_KERNEL,
 		Table:     unix.RT_TABLE_MAIN,
 		Table:     unix.RT_TABLE_MAIN,
 		Type:      unix.RTN_UNICAST,
 		Type:      unix.RTN_UNICAST,
@@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error {
 			continue
 			continue
 		}
 		}
 
 
+		dr := &net.IPNet{
+			IP:   r.Cidr.Masked().Addr().AsSlice(),
+			Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+		}
+
 		nr := netlink.Route{
 		nr := netlink.Route{
 			LinkIndex: t.deviceIndex,
 			LinkIndex: t.deviceIndex,
-			Dst:       r.Cidr,
+			Dst:       dr,
 			MTU:       r.MTU,
 			MTU:       r.MTU,
 			AdvMSS:    t.advMSS(r),
 			AdvMSS:    t.advMSS(r),
 			Scope:     unix.RT_SCOPE_LINK,
 			Scope:     unix.RT_SCOPE_LINK,
@@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) {
 			continue
 			continue
 		}
 		}
 
 
+		dr := &net.IPNet{
+			IP:   r.Cidr.Masked().Addr().AsSlice(),
+			Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()),
+		}
+
 		nr := netlink.Route{
 		nr := netlink.Route{
 			LinkIndex: t.deviceIndex,
 			LinkIndex: t.deviceIndex,
-			Dst:       r.Cidr,
+			Dst:       dr,
 			MTU:       r.MTU,
 			MTU:       r.MTU,
 			AdvMSS:    t.advMSS(r),
 			AdvMSS:    t.advMSS(r),
 			Scope:     unix.RT_SCOPE_LINK,
 			Scope:     unix.RT_SCOPE_LINK,
@@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) {
 	}
 	}
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 
@@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 		return
 	}
 	}
 
 
-	if !t.cidr.Contains(r.Gw) {
+	//TODO: IPV6-WORK what if not ok?
+	gwAddr, ok := netip.AddrFromSlice(r.Gw)
+	if !ok {
+		t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
+		return
+	}
+
+	gwAddr = gwAddr.Unmap()
+	if !t.cidr.Contains(gwAddr) {
 		// Gateway isn't in our overlay network, ignore
 		// Gateway isn't in our overlay network, ignore
 		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
 		t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
 		return
 		return
@@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 		return
 	}
 	}
 
 
-	newTree := cidr.NewTree4[iputil.VpnIp]()
-	if r.Type == unix.RTM_NEWROUTE {
-		for _, oldR := range t.routeTree.Load().List() {
-			newTree.AddCIDR(oldR.CIDR, oldR.Value)
-		}
+	dstAddr, ok := netip.AddrFromSlice(r.Dst.IP)
+	if !ok {
+		t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address")
+		return
+	}
+
+	ones, _ := r.Dst.Mask.Size()
+	dst := netip.PrefixFrom(dstAddr, ones)
+
+	newTree := t.routeTree.Load().Clone()
 
 
+	if r.Type == unix.RTM_NEWROUTE {
 		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
 		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
-		newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw))
+		newTree.Insert(dst, gwAddr)
 
 
 	} else {
 	} else {
-		gw := iputil.Ip2VpnIp(r.Gw)
-		for _, oldR := range t.routeTree.Load().List() {
-			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
-				// This is the record to delete
-				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
-				continue
-			}
-
-			newTree.AddCIDR(oldR.CIDR, oldR.Value)
-		}
+		newTree.Delete(dst)
+		t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 	}
 	}
-
 	t.routeTree.Store(newTree)
 	t.routeTree.Store(newTree)
 }
 }
 
 

+ 14 - 15
overlay/tun_netbsd.go

@@ -6,7 +6,7 @@ package overlay
 import (
 import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"os/exec"
 	"os/exec"
 	"regexp"
 	"regexp"
@@ -15,10 +15,9 @@ import (
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 )
 )
 
 
@@ -29,10 +28,10 @@ type ifreqDestroy struct {
 
 
 type tun struct {
 type tun struct {
 	Device    string
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
@@ -59,13 +58,13 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
 }
 }
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	// Try to open tun device
 	// Try to open tun device
 	var file *os.File
 	var file *os.File
 	var err error
 	var err error
@@ -109,13 +108,13 @@ func (t *tun) Activate() error {
 	var err error
 	var err error
 
 
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 
@@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (t *tun) addRoutes(logErrors bool) error {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
 	for _, r := range routes {
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")

+ 14 - 15
overlay/tun_openbsd.go

@@ -6,7 +6,7 @@ package overlay
 import (
 import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"os/exec"
 	"os/exec"
 	"regexp"
 	"regexp"
@@ -14,19 +14,18 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 	"syscall"
 	"syscall"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 )
 )
 
 
 type tun struct {
 type tun struct {
 	Device    string
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 
 
 	io.ReadWriteCloser
 	io.ReadWriteCloser
@@ -43,13 +42,13 @@ func (t *tun) Close() error {
 	return nil
 	return nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 	return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD")
 }
 }
 
 
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	deviceName := c.GetString("tun.dev", "")
 	if deviceName == "" {
 	if deviceName == "" {
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
 		return nil, fmt.Errorf("a device name in the format of tunN must be specified")
@@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
 func (t *tun) Activate() error {
 func (t *tun) Activate() error {
 	var err error
 	var err error
 	// TODO use syscalls instead of exec.Command
 	// TODO use syscalls instead of exec.Command
-	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String())
+	cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
@@ -139,7 +138,7 @@ func (t *tun) Activate() error {
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 		return fmt.Errorf("failed to run 'ifconfig': %s", err)
 	}
 	}
 
 
-	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String())
+	cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String())
 	t.l.Debug("command: ", cmd.String())
 	t.l.Debug("command: ", cmd.String())
 	if err = cmd.Run(); err != nil {
 	if err = cmd.Run(); err != nil {
 		return fmt.Errorf("failed to run 'route add': %s", err)
 		return fmt.Errorf("failed to run 'route add': %s", err)
@@ -149,20 +148,20 @@ func (t *tun) Activate() error {
 	return t.addRoutes(false)
 	return t.addRoutes(false)
 }
 }
 
 
-func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
 func (t *tun) addRoutes(logErrors bool) error {
 func (t *tun) addRoutes(logErrors bool) error {
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
 	for _, r := range routes {
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err)
@@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String())
+		cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String())
 		t.l.Debug("command: ", cmd.String())
 		t.l.Debug("command: ", cmd.String())
 		if err := cmd.Run(); err != nil {
 		if err := cmd.Run(); err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
@@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *tun) Cidr() *net.IPNet {
+func (t *tun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 9 - 10
overlay/tun_tester.go

@@ -6,21 +6,20 @@ package overlay
 import (
 import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"sync/atomic"
 	"sync/atomic"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type TestTun struct {
 type TestTun struct {
 	Device    string
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	Routes    []Route
 	Routes    []Route
-	routeTree *cidr.Tree4[iputil.VpnIp]
+	routeTree *bart.Table[netip.Addr]
 	l         *logrus.Logger
 	l         *logrus.Logger
 
 
 	closed    atomic.Bool
 	closed    atomic.Bool
@@ -28,7 +27,7 @@ type TestTun struct {
 	TxPackets chan []byte // Packets transmitted outside by nebula
 	TxPackets chan []byte // Packets transmitted outside by nebula
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) {
 	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
 	_, routes, err := getAllRoutesFromConfig(c, cidr, true)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e
 	}, nil
 	}, nil
 }
 }
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported")
 	return nil, fmt.Errorf("newTunFromFd not supported")
 }
 }
 
 
@@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte {
 // Below this is boilerplate implementation to make nebula actually work
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 //********************************************************************************************************************//
 
 
-func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.MostSpecificContains(ip)
+func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Lookup(ip)
 	return r
 	return r
 }
 }
 
 
@@ -96,7 +95,7 @@ func (t *TestTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *TestTun) Cidr() *net.IPNet {
+func (t *TestTun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 11 - 11
overlay/tun_water_windows.go

@@ -4,30 +4,30 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
+	"net/netip"
 	"os/exec"
 	"os/exec"
 	"strconv"
 	"strconv"
 	"sync/atomic"
 	"sync/atomic"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 	"github.com/songgao/water"
 	"github.com/songgao/water"
 )
 )
 
 
 type waterTun struct {
 type waterTun struct {
 	Device    string
 	Device    string
-	cidr      *net.IPNet
+	cidr      netip.Prefix
 	MTU       int
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 	f         *net.Interface
 	f         *net.Interface
 	*water.Interface
 	*water.Interface
 }
 }
 
 
-func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) {
+func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) {
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	// NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate()
 	t := &waterTun{
 	t := &waterTun{
 		cidr: cidr,
 		cidr: cidr,
@@ -70,8 +70,8 @@ func (t *waterTun) Activate() error {
 		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
 		`C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address",
 		fmt.Sprintf("name=%s", t.Device),
 		fmt.Sprintf("name=%s", t.Device),
 		"source=static",
 		"source=static",
-		fmt.Sprintf("addr=%s", t.cidr.IP),
-		fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)),
+		fmt.Sprintf("addr=%s", t.cidr.Addr()),
+		fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())),
 		"gateway=none",
 		"gateway=none",
 	).Run()
 	).Run()
 	if err != nil {
 	if err != nil {
@@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error {
 	// Path routes
 	// Path routes
 	routes := *t.Routes.Load()
 	routes := *t.Routes.Load()
 	for _, r := range routes {
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
@@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) {
 	}
 	}
 }
 }
 
 
-func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
-func (t *waterTun) Cidr() *net.IPNet {
+func (t *waterTun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 3 - 3
overlay/tun_windows.go

@@ -5,7 +5,7 @@ package overlay
 
 
 import (
 import (
 	"fmt"
 	"fmt"
-	"net"
+	"net/netip"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
@@ -15,11 +15,11 @@ import (
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 )
 )
 
 
-func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) {
+func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) {
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 	return nil, fmt.Errorf("newTunFromFd not supported in Windows")
 }
 }
 
 
-func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) {
+func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) {
 	useWintun := true
 	useWintun := true
 	if err := checkWinTunExists(); err != nil {
 	if err := checkWinTunExists(); err != nil {
 		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")
 		l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver")

+ 12 - 38
overlay/tun_wintun_windows.go

@@ -4,15 +4,13 @@ import (
 	"crypto"
 	"crypto"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
-	"net"
 	"net/netip"
 	"net/netip"
 	"sync/atomic"
 	"sync/atomic"
 	"unsafe"
 	"unsafe"
 
 
+	"github.com/gaissmai/bart"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cidr"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/util"
 	"github.com/slackhq/nebula/wintun"
 	"github.com/slackhq/nebula/wintun"
 	"golang.org/x/sys/windows"
 	"golang.org/x/sys/windows"
@@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
 
 
 type winTun struct {
 type winTun struct {
 	Device    string
 	Device    string
-	cidr      *net.IPNet
-	prefix    netip.Prefix
+	cidr      netip.Prefix
 	MTU       int
 	MTU       int
 	Routes    atomic.Pointer[[]Route]
 	Routes    atomic.Pointer[[]Route]
-	routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
+	routeTree atomic.Pointer[bart.Table[netip.Addr]]
 	l         *logrus.Logger
 	l         *logrus.Logger
 
 
 	tun *wintun.NativeTun
 	tun *wintun.NativeTun
@@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) {
 	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 	return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil
 }
 }
 
 
-func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) {
+func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) {
 	deviceName := c.GetString("tun.dev", "")
 	deviceName := c.GetString("tun.dev", "")
 	guid, err := generateGUIDByDeviceName(deviceName)
 	guid, err := generateGUIDByDeviceName(deviceName)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
 		return nil, fmt.Errorf("generate GUID failed: %w", err)
 	}
 	}
 
 
-	prefix, err := iputil.ToNetIpPrefix(*cidr)
-	if err != nil {
-		return nil, err
-	}
-
 	t := &winTun{
 	t := &winTun{
 		Device: deviceName,
 		Device: deviceName,
 		cidr:   cidr,
 		cidr:   cidr,
-		prefix: prefix,
 		MTU:    c.GetInt("tun.mtu", DefaultMTU),
 		MTU:    c.GetInt("tun.mtu", DefaultMTU),
 		l:      l,
 		l:      l,
 	}
 	}
@@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error {
 func (t *winTun) Activate() error {
 func (t *winTun) Activate() error {
 	luid := winipcfg.LUID(t.tun.LUID())
 	luid := winipcfg.LUID(t.tun.LUID())
 
 
-	err := luid.SetIPAddresses([]netip.Prefix{t.prefix})
+	err := luid.SetIPAddresses([]netip.Prefix{t.cidr})
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("failed to set address: %w", err)
 		return fmt.Errorf("failed to set address: %w", err)
 	}
 	}
@@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error {
 	foundDefault4 := false
 	foundDefault4 := false
 
 
 	for _, r := range routes {
 	for _, r := range routes {
-		if r.Via == nil || !r.Install {
+		if !r.Via.IsValid() || !r.Install {
 			// We don't allow route MTUs so only install routes with a via
 			// We don't allow route MTUs so only install routes with a via
 			continue
 			continue
 		}
 		}
 
 
-		prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
-		if err != nil {
-			retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err)
-			if logErrors {
-				retErr.Log(t.l)
-				continue
-			} else {
-				return retErr
-			}
-		}
-
 		// Add our unsafe route
 		// Add our unsafe route
-		err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric))
+		err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
 		if err != nil {
 		if err != nil {
 			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
 			retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
 			if logErrors {
 			if logErrors {
@@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error {
 		}
 		}
 
 
 		if !foundDefault4 {
 		if !foundDefault4 {
-			if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 {
+			if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 {
 				foundDefault4 = true
 				foundDefault4 = true
 			}
 			}
 		}
 		}
@@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
 			continue
 			continue
 		}
 		}
 
 
-		prefix, err := iputil.ToNetIpPrefix(*r.Cidr)
-		if err != nil {
-			t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix")
-			continue
-		}
-
-		err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr())
+		err := luid.DeleteRoute(r.Cidr, r.Via)
 		if err != nil {
 		if err != nil {
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 			t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
 		} else {
 		} else {
@@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error {
 	return nil
 	return nil
 }
 }
 
 
-func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	_, r := t.routeTree.Load().MostSpecificContains(ip)
+func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
+	r, _ := t.routeTree.Load().Lookup(ip)
 	return r
 	return r
 }
 }
 
 
-func (t *winTun) Cidr() *net.IPNet {
+func (t *winTun) Cidr() netip.Prefix {
 	return t.cidr
 	return t.cidr
 }
 }
 
 

+ 7 - 8
overlay/user.go

@@ -2,18 +2,17 @@ package overlay
 
 
 import (
 import (
 	"io"
 	"io"
-	"net"
+	"net/netip"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
-func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) {
+func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) {
 	return NewUserDevice(tunCidr)
 	return NewUserDevice(tunCidr)
 }
 }
 
 
-func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
+func NewUserDevice(tunCidr netip.Prefix) (Device, error) {
 	// these pipes guarantee each write/read will match 1:1
 	// these pipes guarantee each write/read will match 1:1
 	or, ow := io.Pipe()
 	or, ow := io.Pipe()
 	ir, iw := io.Pipe()
 	ir, iw := io.Pipe()
@@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) {
 }
 }
 
 
 type UserDevice struct {
 type UserDevice struct {
-	tunCidr *net.IPNet
+	tunCidr netip.Prefix
 
 
 	outboundReader *io.PipeReader
 	outboundReader *io.PipeReader
 	outboundWriter *io.PipeWriter
 	outboundWriter *io.PipeWriter
@@ -39,9 +38,9 @@ type UserDevice struct {
 func (d *UserDevice) Activate() error {
 func (d *UserDevice) Activate() error {
 	return nil
 	return nil
 }
 }
-func (d *UserDevice) Cidr() *net.IPNet                      { return d.tunCidr }
-func (d *UserDevice) Name() string                          { return "faketun0" }
-func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip }
+func (d *UserDevice) Cidr() netip.Prefix                { return d.tunCidr }
+func (d *UserDevice) Name() string                      { return "faketun0" }
+func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 	return d, nil
 	return d, nil
 }
 }

+ 2 - 0
pki.go

@@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
 	}
 	}
 
 
 	if !initial {
 	if !initial {
+		//TODO: include check for mask equality as well
+
 		// did IP in cert change? if so, don't set
 		// did IP in cert change? if so, don't set
 		currentCert := p.cs.Load().Certificate
 		currentCert := p.cs.Load().Certificate
 		oldIPs := currentCert.Details.Ips
 		oldIPs := currentCert.Details.Ips

+ 52 - 31
relay_manager.go

@@ -2,14 +2,15 @@ package nebula
 
 
 import (
 import (
 	"context"
 	"context"
+	"encoding/binary"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"net/netip"
 	"sync/atomic"
 	"sync/atomic"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type relayManager struct {
 type relayManager struct {
@@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) {
 
 
 // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
 // AddRelay finds an available relay index on the hostmap, and associates the relay info with it.
 // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
 // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp.
-func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) {
+func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) {
 	hm.Lock()
 	hm.Lock()
 	defer hm.Unlock()
 	defer hm.Unlock()
 	for i := 0; i < 32; i++ {
 	for i := 0; i < 32; i++ {
@@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter
 
 
 func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
 func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) {
 	rm.l.WithFields(logrus.Fields{
 	rm.l.WithFields(logrus.Fields{
-		"relayFrom":           iputil.VpnIp(m.RelayFromIp),
-		"relayTo":             iputil.VpnIp(m.RelayToIp),
+		"relayFrom":           m.RelayFromIp,
+		"relayTo":             m.RelayToIp,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
 		"vpnIp":               h.vpnIp}).
 		"vpnIp":               h.vpnIp}).
 		Info("handleCreateRelayResponse")
 		Info("handleCreateRelayResponse")
-	target := iputil.VpnIp(m.RelayToIp)
+	target := m.RelayToIp
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+	targetAddr := netip.AddrFrom4(b)
 
 
 	relay, err := rm.EstablishRelay(h, m)
 	relay, err := rm.EstablishRelay(h, m)
 	if err != nil {
 	if err != nil {
@@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
 		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
 		return
 		return
 	}
 	}
-	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
+	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr)
 	if !ok {
 	if !ok {
 		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
 		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
 		return
 		return
 	}
 	}
 	if peerRelay.State == PeerRequested {
 	if peerRelay.State == PeerRequested {
+		//TODO: IPV6-WORK
+		b = peerHostInfo.vpnIp.As4()
 		peerRelay.State = Established
 		peerRelay.State = Established
 		resp := NebulaControl{
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: peerRelay.LocalIndex,
 			ResponderRelayIndex: peerRelay.LocalIndex,
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
-			RelayFromIp:         uint32(peerHostInfo.vpnIp),
+			RelayFromIp:         binary.BigEndian.Uint32(b[:]),
 			RelayToIp:           uint32(target),
 			RelayToIp:           uint32(target),
 		}
 		}
 		msg, err := resp.Marshal()
 		msg, err := resp.Marshal()
@@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		} else {
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-				"relayTo":             iputil.VpnIp(resp.RelayToIp),
+				"relayFrom":           resp.RelayFromIp,
+				"relayTo":             resp.RelayToIp,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"vpnIp":               peerHostInfo.vpnIp}).
 				"vpnIp":               peerHostInfo.vpnIp}).
@@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 }
 }
 
 
 func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
 func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) {
+	//TODO: IPV6-WORK
+	b := [4]byte{}
+	binary.BigEndian.PutUint32(b[:], m.RelayFromIp)
+	from := netip.AddrFrom4(b)
 
 
-	from := iputil.VpnIp(m.RelayFromIp)
-	target := iputil.VpnIp(m.RelayToIp)
+	binary.BigEndian.PutUint32(b[:], m.RelayToIp)
+	target := netip.AddrFrom4(b)
 
 
 	logMsg := rm.l.WithFields(logrus.Fields{
 	logMsg := rm.l.WithFields(logrus.Fields{
 		"relayFrom":           from,
 		"relayFrom":           from,
@@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 	logMsg.Info("handleCreateRelayRequest")
 	logMsg.Info("handleCreateRelayRequest")
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// Is the source of the relay me? This should never happen, but did happen due to
 	// an issue migrating relays over to newly re-handshaked host info objects.
 	// an issue migrating relays over to newly re-handshaked host info objects.
-	if from == f.myVpnIp {
-		logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself")
+	if from == f.myVpnNet.Addr() {
+		logMsg.WithField("myIP", from).Error("Discarding relay request from myself")
 		return
 		return
 	}
 	}
 	// Is the target of the relay me?
 	// Is the target of the relay me?
-	if target == f.myVpnIp {
+	if target == f.myVpnNet.Addr() {
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		existingRelay, ok := h.relayState.QueryRelayForByIp(from)
 		if ok {
 		if ok {
 			switch existingRelay.State {
 			switch existingRelay.State {
@@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			return
 			return
 		}
 		}
 
 
+		//TODO: IPV6-WORK
+		fromB := from.As4()
+		targetB := target.As4()
+
 		resp := NebulaControl{
 		resp := NebulaControl{
 			Type:                NebulaControl_CreateRelayResponse,
 			Type:                NebulaControl_CreateRelayResponse,
 			ResponderRelayIndex: relay.LocalIndex,
 			ResponderRelayIndex: relay.LocalIndex,
 			InitiatorRelayIndex: relay.RemoteIndex,
 			InitiatorRelayIndex: relay.RemoteIndex,
-			RelayFromIp:         uint32(from),
-			RelayToIp:           uint32(target),
+			RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
+			RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 		}
 		}
 		msg, err := resp.Marshal()
 		msg, err := resp.Marshal()
 		if err != nil {
 		if err != nil {
@@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		} else {
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 			f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-				"relayTo":             iputil.VpnIp(resp.RelayToIp),
+				//TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now
+				"relayFrom":           from,
+				"relayTo":             target,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"vpnIp":               h.vpnIp}).
 				"vpnIp":               h.vpnIp}).
@@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			f.Handshake(target)
 			f.Handshake(target)
 			return
 			return
 		}
 		}
-		if peer.remote == nil {
+		if !peer.remote.IsValid() {
 			// Only create relays to peers for whom I have a direct connection
 			// Only create relays to peers for whom I have a direct connection
 			return
 			return
 		}
 		}
@@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			sendCreateRequest = true
 			sendCreateRequest = true
 		}
 		}
 		if sendCreateRequest {
 		if sendCreateRequest {
+			//TODO: IPV6-WORK
+			fromB := h.vpnIp.As4()
+			targetB := target.As4()
+
 			// Send a CreateRelayRequest to the peer.
 			// Send a CreateRelayRequest to the peer.
 			req := NebulaControl{
 			req := NebulaControl{
 				Type:                NebulaControl_CreateRelayRequest,
 				Type:                NebulaControl_CreateRelayRequest,
 				InitiatorRelayIndex: index,
 				InitiatorRelayIndex: index,
-				RelayFromIp:         uint32(h.vpnIp),
-				RelayToIp:           uint32(target),
+				RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
+				RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 			}
 			}
 			msg, err := req.Marshal()
 			msg, err := req.Marshal()
 			if err != nil {
 			if err != nil {
@@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			} else {
 			} else {
 				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				rm.l.WithFields(logrus.Fields{
 				rm.l.WithFields(logrus.Fields{
-					"relayFrom":           iputil.VpnIp(req.RelayFromIp),
-					"relayTo":             iputil.VpnIp(req.RelayToIp),
+					//TODO: IPV6-WORK another lazy used to use the req object
+					"relayFrom":           h.vpnIp,
+					"relayTo":             target,
 					"initiatorRelayIndex": req.InitiatorRelayIndex,
 					"initiatorRelayIndex": req.InitiatorRelayIndex,
 					"responderRelayIndex": req.ResponderRelayIndex,
 					"responderRelayIndex": req.ResponderRelayIndex,
 					"vpnIp":               target}).
 					"vpnIp":               target}).
@@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 						"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
 						"existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest")
 					return
 					return
 				}
 				}
+				//TODO: IPV6-WORK
+				fromB := h.vpnIp.As4()
+				targetB := target.As4()
 				resp := NebulaControl{
 				resp := NebulaControl{
 					Type:                NebulaControl_CreateRelayResponse,
 					Type:                NebulaControl_CreateRelayResponse,
 					ResponderRelayIndex: relay.LocalIndex,
 					ResponderRelayIndex: relay.LocalIndex,
 					InitiatorRelayIndex: relay.RemoteIndex,
 					InitiatorRelayIndex: relay.RemoteIndex,
-					RelayFromIp:         uint32(h.vpnIp),
-					RelayToIp:           uint32(target),
+					RelayFromIp:         binary.BigEndian.Uint32(fromB[:]),
+					RelayToIp:           binary.BigEndian.Uint32(targetB[:]),
 				}
 				}
 				msg, err := resp.Marshal()
 				msg, err := resp.Marshal()
 				if err != nil {
 				if err != nil {
@@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 				} else {
 				} else {
 					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					rm.l.WithFields(logrus.Fields{
 					rm.l.WithFields(logrus.Fields{
-						"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-						"relayTo":             iputil.VpnIp(resp.RelayToIp),
+						//TODO: IPV6-WORK more lazy, used to use resp object
+						"relayFrom":           h.vpnIp,
+						"relayTo":             target,
 						"initiatorRelayIndex": resp.InitiatorRelayIndex,
 						"initiatorRelayIndex": resp.InitiatorRelayIndex,
 						"responderRelayIndex": resp.ResponderRelayIndex,
 						"responderRelayIndex": resp.ResponderRelayIndex,
 						"vpnIp":               h.vpnIp}).
 						"vpnIp":               h.vpnIp}).
@@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		}
 		}
 	}
 	}
 }
 }
-
-func (rm *relayManager) RemoveRelay(localIdx uint32) {
-	rm.hostmap.RemoveRelay(localIdx)
-}

+ 73 - 93
remote_list.go

@@ -1,7 +1,6 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"bytes"
 	"context"
 	"context"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
@@ -12,16 +11,14 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/iputil"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 // forEachFunc is used to benefit folks that want to do work inside the lock
 // forEachFunc is used to benefit folks that want to do work inside the lock
-type forEachFunc func(addr *udp.Addr, preferred bool)
+type forEachFunc func(addr netip.AddrPort, preferred bool)
 
 
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
 // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
-type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool
-type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool
+type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool
+type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool
 
 
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // CacheMap is a struct that better represents the lighthouse cache for humans
 // The string key is the owners vpnIp
 // The string key is the owners vpnIp
@@ -30,9 +27,9 @@ type CacheMap map[string]*Cache
 // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
 // Cache is the other part of CacheMap to better represent the lighthouse cache for humans
 // We don't reason about ipv4 vs ipv6 here
 // We don't reason about ipv4 vs ipv6 here
 type Cache struct {
 type Cache struct {
-	Learned  []*udp.Addr `json:"learned,omitempty"`
-	Reported []*udp.Addr `json:"reported,omitempty"`
-	Relay    []*net.IP   `json:"relay"`
+	Learned  []netip.AddrPort `json:"learned,omitempty"`
+	Reported []netip.AddrPort `json:"reported,omitempty"`
+	Relay    []netip.Addr     `json:"relay"`
 }
 }
 
 
 //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
 //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion
@@ -46,7 +43,7 @@ type cache struct {
 }
 }
 
 
 type cacheRelay struct {
 type cacheRelay struct {
-	relay []uint32
+	relay []netip.Addr
 }
 }
 
 
 // cacheV4 stores learned and reported ipv4 records under cache
 // cacheV4 stores learned and reported ipv4 records under cache
@@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration,
 						continue
 						continue
 					}
 					}
 					for _, a := range addrs {
 					for _, a := range addrs {
-						netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{}
+						netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{}
 					}
 					}
 				}
 				}
 				origSet := r.ips.Load()
 				origSet := r.ips.Load()
@@ -193,22 +190,22 @@ type RemoteList struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
 	// A deduplicated set of addresses. Any accessor should lock beforehand.
-	addrs []*udp.Addr
+	addrs []netip.AddrPort
 
 
 	// A set of relay addresses. VpnIp addresses that the remote identified as relays.
 	// A set of relay addresses. VpnIp addresses that the remote identified as relays.
-	relays []*iputil.VpnIp
+	relays []netip.Addr
 
 
 	// These are maps to store v4 and v6 addresses per lighthouse
 	// These are maps to store v4 and v6 addresses per lighthouse
 	// Map key is the vpnIp of the person that told us about this the cached entries underneath.
 	// Map key is the vpnIp of the person that told us about this the cached entries underneath.
 	// For learned addresses, this is the vpnIp that sent the packet
 	// For learned addresses, this is the vpnIp that sent the packet
-	cache map[iputil.VpnIp]*cache
+	cache map[netip.Addr]*cache
 
 
 	hr        *hostnamesResults
 	hr        *hostnamesResults
 	shouldAdd func(netip.Addr) bool
 	shouldAdd func(netip.Addr) bool
 
 
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
 	// They should not be tried again during a handshake
 	// They should not be tried again during a handshake
-	badRemotes []*udp.Addr
+	badRemotes []netip.AddrPort
 
 
 	// A flag that the cache may have changed and addrs needs to be rebuilt
 	// A flag that the cache may have changed and addrs needs to be rebuilt
 	shouldRebuild bool
 	shouldRebuild bool
@@ -217,9 +214,9 @@ type RemoteList struct {
 // NewRemoteList creates a new empty RemoteList
 // NewRemoteList creates a new empty RemoteList
 func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
 func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList {
 	return &RemoteList{
 	return &RemoteList{
-		addrs:     make([]*udp.Addr, 0),
-		relays:    make([]*iputil.VpnIp, 0),
-		cache:     make(map[iputil.VpnIp]*cache),
+		addrs:     make([]netip.AddrPort, 0),
+		relays:    make([]netip.Addr, 0),
+		cache:     make(map[netip.Addr]*cache),
 		shouldAdd: shouldAdd,
 		shouldAdd: shouldAdd,
 	}
 	}
 }
 }
@@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) {
 
 
 // Len locks and reports the size of the deduplicated address list
 // Len locks and reports the size of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
+func (r *RemoteList) Len(preferredRanges []netip.Prefix) int {
 	r.Rebuild(preferredRanges)
 	r.Rebuild(preferredRanges)
 	r.RLock()
 	r.RLock()
 	defer r.RUnlock()
 	defer r.RUnlock()
@@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
 
 
 // ForEach locks and will call the forEachFunc for every deduplicated address in the list
 // ForEach locks and will call the forEachFunc for every deduplicated address in the list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
+func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) {
 	r.Rebuild(preferredRanges)
 	r.Rebuild(preferredRanges)
 	r.RLock()
 	r.RLock()
 	for _, v := range r.addrs {
 	for _, v := range r.addrs {
-		forEach(v, isPreferred(v.IP, preferredRanges))
+		forEach(v, isPreferred(v.Addr(), preferredRanges))
 	}
 	}
 	r.RUnlock()
 	r.RUnlock()
 }
 }
 
 
 // CopyAddrs locks and makes a deep copy of the deduplicated address list
 // CopyAddrs locks and makes a deep copy of the deduplicated address list
 // The deduplication work may need to occur here, so you must pass preferredRanges
 // The deduplication work may need to occur here, so you must pass preferredRanges
-func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
+func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort {
 	if r == nil {
 	if r == nil {
 		return nil
 		return nil
 	}
 	}
@@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
 
 
 	r.RLock()
 	r.RLock()
 	defer r.RUnlock()
 	defer r.RUnlock()
-	c := make([]*udp.Addr, len(r.addrs))
+	c := make([]netip.AddrPort, len(r.addrs))
 	for i, v := range r.addrs {
 	for i, v := range r.addrs {
-		c[i] = v.Copy()
+		c[i] = v
 	}
 	}
 	return c
 	return c
 }
 }
@@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr {
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
 // It will mark the deduplicated address list as dirty, so do not call it unless new information is available
 // TODO: this needs to support the allow list list
 // TODO: this needs to support the allow list list
-func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) {
+func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) {
 	r.Lock()
 	r.Lock()
 	defer r.Unlock()
 	defer r.Unlock()
-	if v4 := addr.IP.To4(); v4 != nil {
-		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
+	if remote.Addr().Is4() {
+		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port()))
 	} else {
 	} else {
-		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
+		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port()))
 	}
 	}
 }
 }
 
 
@@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap {
 		c := cm[vpnIp]
 		c := cm[vpnIp]
 		if c == nil {
 		if c == nil {
 			c = &Cache{
 			c = &Cache{
-				Learned:  make([]*udp.Addr, 0),
-				Reported: make([]*udp.Addr, 0),
-				Relay:    make([]*net.IP, 0),
+				Learned:  make([]netip.AddrPort, 0),
+				Reported: make([]netip.AddrPort, 0),
+				Relay:    make([]netip.Addr, 0),
 			}
 			}
 			cm[vpnIp] = c
 			cm[vpnIp] = c
 		}
 		}
@@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap {
 
 
 		if mc.v4 != nil {
 		if mc.v4 != nil {
 			if mc.v4.learned != nil {
 			if mc.v4.learned != nil {
-				c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
+				c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned))
 			}
 			}
 
 
 			for _, a := range mc.v4.reported {
 			for _, a := range mc.v4.reported {
-				c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
+				c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a))
 			}
 			}
 		}
 		}
 
 
 		if mc.v6 != nil {
 		if mc.v6 != nil {
 			if mc.v6.learned != nil {
 			if mc.v6.learned != nil {
-				c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
+				c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned))
 			}
 			}
 
 
 			for _, a := range mc.v6.reported {
 			for _, a := range mc.v6.reported {
-				c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
+				c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a))
 			}
 			}
 		}
 		}
 
 
 		if mc.relay != nil {
 		if mc.relay != nil {
 			for _, a := range mc.relay.relay {
 			for _, a := range mc.relay.relay {
-				nip := iputil.VpnIp(a).ToIP()
-				c.Relay = append(c.Relay, &nip)
+				c.Relay = append(c.Relay, a)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap {
 }
 }
 
 
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad *udp.Addr) {
-	if bad == nil {
+func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
+	if !bad.IsValid() {
 		// relays can have nil udp Addrs
 		// relays can have nil udp Addrs
 		return
 		return
 	}
 	}
@@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) {
 	}
 	}
 
 
 	// We copy here because we are taking something else's memory and we can't trust everything
 	// We copy here because we are taking something else's memory and we can't trust everything
-	r.badRemotes = append(r.badRemotes, bad.Copy())
+	r.badRemotes = append(r.badRemotes, bad)
 
 
 	// Mark the next interaction must recollect/dedupe
 	// Mark the next interaction must recollect/dedupe
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 }
 }
 
 
 // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
 // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
-func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr {
+func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort {
 	r.RLock()
 	r.RLock()
 	defer r.RUnlock()
 	defer r.RUnlock()
 
 
-	c := make([]*udp.Addr, len(r.badRemotes))
+	c := make([]netip.AddrPort, len(r.badRemotes))
 	for i, v := range r.badRemotes {
 	for i, v := range r.badRemotes {
-		c[i] = v.Copy()
+		c[i] = v
 	}
 	}
 	return c
 	return c
 }
 }
@@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() {
 
 
 // Rebuild locks and generates the deduplicated address list only if there is work to be done
 // Rebuild locks and generates the deduplicated address list only if there is work to be done
 // There is generally no reason to call this directly but it is safe to do so
 // There is generally no reason to call this directly but it is safe to do so
-func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
+func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) {
 	r.Lock()
 	r.Lock()
 	defer r.Unlock()
 	defer r.Unlock()
 
 
@@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
 }
 }
 
 
 // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
 // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
-func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
+func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool {
 	for _, v := range r.badRemotes {
 	for _, v := range r.badRemotes {
-		if v.Equals(remote) {
+		if v == remote {
 			return true
 			return true
 		}
 		}
 	}
 	}
@@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool {
 
 
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
 }
 }
 
 
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) {
+func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 
@@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
 	}
 	}
 }
 }
 
 
-func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) {
+func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeRelay(ownerVpnIp)
 	c := r.unlockedGetOrMakeRelay(ownerVpnIp)
 
 
@@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI
 
 
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) {
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 	c := r.unlockedGetOrMakeV4(ownerVpnIp)
 
 
@@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort)
 
 
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
 // deduplicated address list as dirty
 // deduplicated address list as dirty
-func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
 }
 }
 
 
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
 // and marks the deduplicated address list as dirty
 // and marks the deduplicated address list as dirty
-func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) {
+func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 
@@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp,
 
 
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
 // This is only useful for establishing static hosts
 // This is only useful for establishing static hosts
-func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) {
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) {
 	r.shouldRebuild = true
 	r.shouldRebuild = true
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 	c := r.unlockedGetOrMakeV6(ownerVpnIp)
 
 
@@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort)
 	}
 	}
 }
 }
 
 
-func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay {
+func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay {
 	am := r.cache[ownerVpnIp]
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 	if am == nil {
 		am = &cache{}
 		am = &cache{}
@@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay
 
 
 // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
 // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
 // The caller must dirty the learned address cache if required
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 {
 	am := r.cache[ownerVpnIp]
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 	if am == nil {
 		am = &cache{}
 		am = &cache{}
@@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 {
 
 
 // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
 // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
 // The caller must dirty the learned address cache if required
 // The caller must dirty the learned address cache if required
-func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 {
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 {
 	am := r.cache[ownerVpnIp]
 	am := r.cache[ownerVpnIp]
 	if am == nil {
 	if am == nil {
 		am = &cache{}
 		am = &cache{}
@@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() {
 	for _, c := range r.cache {
 	for _, c := range r.cache {
 		if c.v4 != nil {
 		if c.v4 != nil {
 			if c.v4.learned != nil {
 			if c.v4.learned != nil {
-				u := NewUDPAddrFromLH4(c.v4.learned)
+				u := AddrPortFromIp4AndPort(c.v4.learned)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
 			}
 			}
 
 
 			for _, v := range c.v4.reported {
 			for _, v := range c.v4.reported {
-				u := NewUDPAddrFromLH4(v)
+				u := AddrPortFromIp4AndPort(v)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
@@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() {
 
 
 		if c.v6 != nil {
 		if c.v6 != nil {
 			if c.v6.learned != nil {
 			if c.v6.learned != nil {
-				u := NewUDPAddrFromLH6(c.v6.learned)
+				u := AddrPortFromIp6AndPort(c.v6.learned)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
 			}
 			}
 
 
 			for _, v := range c.v6.reported {
 			for _, v := range c.v6.reported {
-				u := NewUDPAddrFromLH6(v)
+				u := AddrPortFromIp6AndPort(v)
 				if !r.unlockedIsBad(u) {
 				if !r.unlockedIsBad(u) {
 					addrs = append(addrs, u)
 					addrs = append(addrs, u)
 				}
 				}
@@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() {
 
 
 		if c.relay != nil {
 		if c.relay != nil {
 			for _, v := range c.relay.relay {
 			for _, v := range c.relay.relay {
-				ip := iputil.VpnIp(v)
-				relays = append(relays, &ip)
+				relays = append(relays, v)
 			}
 			}
 		}
 		}
 	}
 	}
@@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() {
 	dnsAddrs := r.hr.GetIPs()
 	dnsAddrs := r.hr.GetIPs()
 	for _, addr := range dnsAddrs {
 	for _, addr := range dnsAddrs {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
 		if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) {
-			v6 := addr.Addr().As16()
-			addrs = append(addrs, &udp.Addr{
-				IP:   v6[:],
-				Port: addr.Port(),
-			})
+			addrs = append(addrs, addr)
 		}
 		}
 	}
 	}
 
 
@@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() {
 }
 }
 
 
 // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
 // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
-func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) {
 	n := len(r.addrs)
 	n := len(r.addrs)
 	if n < 2 {
 	if n < 2 {
 		return
 		return
@@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 		b := r.addrs[j]
 		b := r.addrs[j]
 		// Preferred addresses first
 		// Preferred addresses first
 
 
-		aPref := isPreferred(a.IP, preferredRanges)
-		bPref := isPreferred(b.IP, preferredRanges)
+		aPref := isPreferred(a.Addr(), preferredRanges)
+		bPref := isPreferred(b.Addr(), preferredRanges)
 		switch {
 		switch {
 		case aPref && !bPref:
 		case aPref && !bPref:
 			// If i is preferred and j is not, i is less than j
 			// If i is preferred and j is not, i is less than j
@@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 		}
 		}
 
 
 		// ipv6 addresses 2nd
 		// ipv6 addresses 2nd
-		a4 := a.IP.To4()
-		b4 := b.IP.To4()
+		a4 := a.Addr().Is4()
+		b4 := b.Addr().Is4()
 		switch {
 		switch {
-		case a4 == nil && b4 != nil:
+		case a4 == false && b4 == true:
 			// If i is v6 and j is v4, i is less than j
 			// If i is v6 and j is v4, i is less than j
 			return true
 			return true
 
 
-		case a4 != nil && b4 == nil:
+		case a4 == true && b4 == false:
 			// If j is v6 and i is v4, i is not less than j
 			// If j is v6 and i is v4, i is not less than j
 			return false
 			return false
 
 
-		case a4 != nil && b4 != nil:
-			// Special case for ipv4, a4 and b4 are not nil
-			aPrivate := isPrivateIP(a4)
-			bPrivate := isPrivateIP(b4)
+		case a4 == true && b4 == true:
+			// i and j are both ipv4
+			aPrivate := a.Addr().IsPrivate()
+			bPrivate := b.Addr().IsPrivate()
 			switch {
 			switch {
 			case !aPrivate && bPrivate:
 			case !aPrivate && bPrivate:
 				// If i is a public ip (not private) and j is a private ip, i is less then j
 				// If i is a public ip (not private) and j is a private ip, i is less then j
@@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 		}
 		}
 
 
 		// lexical order of ips 3rd
 		// lexical order of ips 3rd
-		c := bytes.Compare(a.IP, b.IP)
+		c := a.Addr().Compare(b.Addr())
 		if c == 0 {
 		if c == 0 {
 			// Ips are the same, Lexical order of ports 4th
 			// Ips are the same, Lexical order of ports 4th
-			return a.Port < b.Port
+			return a.Port() < b.Port()
 		}
 		}
 
 
 		// Ip wasn't the same
 		// Ip wasn't the same
@@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
 	// Deduplicate
 	// Deduplicate
 	a, b := 0, 1
 	a, b := 0, 1
 	for b < n {
 	for b < n {
-		if !r.addrs[a].Equals(r.addrs[b]) {
+		if r.addrs[a] != r.addrs[b] {
 			a++
 			a++
 			if a != b {
 			if a != b {
 				r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
 				r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
@@ -693,7 +684,7 @@ func minInt(a, b int) int {
 }
 }
 
 
 // isPreferred returns true of the ip is contained in the preferredRanges list
 // isPreferred returns true of the ip is contained in the preferredRanges list
-func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool {
 	//TODO: this would be better in a CIDR6Tree
 	//TODO: this would be better in a CIDR6Tree
 	for _, p := range preferredRanges {
 	for _, p := range preferredRanges {
 		if p.Contains(ip) {
 		if p.Contains(ip) {
@@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
 	}
 	}
 	return false
 	return false
 }
 }
-
-var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
-var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
-var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
-
-// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
-func isPrivateIP(ip net.IP) bool {
-	//TODO: another great cidrtree option
-	//TODO: Private for ipv6 or just let it ride?
-	return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
-}

+ 98 - 89
remote_list_test.go

@@ -1,47 +1,47 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"net"
+	"encoding/binary"
+	"net/netip"
 	"testing"
 	"testing"
 
 
-	"github.com/slackhq/nebula/iputil"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
 func TestRemoteList_Rebuild(t *testing.T) {
 func TestRemoteList_Rebuild(t *testing.T) {
 	rl := NewRemoteList(nil)
 	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1475"), // this is duped
+			newIp4AndPortFromString("172.17.0.182:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"), // this is duped
+			newIp4AndPortFromString("172.18.0.1:10101"), // this is duped
+			newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe
+			newIp4AndPortFromString("172.19.0.1:10101"),
+			newIp4AndPortFromString("172.31.0.1:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port
+			newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe
 		},
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
-		1,
-		1,
+		netip.MustParseAddr("0.0.0.1"),
+		netip.MustParseAddr("0.0.0.1"),
 		[]*Ip6AndPort{
 		[]*Ip6AndPort{
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
-			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+			newIp6AndPortFromString("[1::1]:1"), // this is duped
+			newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped
+			newIp6AndPortFromString("[1:100::1]:1"),
+			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
+			newIp6AndPortFromString("[1::1]:2"), // this is a dupe
 		},
 		},
-		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *Ip6AndPort) bool { return true },
 	)
 	)
 
 
-	rl.Rebuild([]*net.IPNet{})
+	rl.Rebuild([]netip.Prefix{})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 
 
 	// ipv6 first, sorted lexically within
 	// ipv6 first, sorted lexically within
@@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
 	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
 
 
 	// Now ensure we can hoist ipv4 up
 	// Now ensure we can hoist ipv4 up
-	_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
-	assert.NoError(t, err)
-	rl.Rebuild([]*net.IPNet{ipNet})
+	rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 
 
 	// ipv4 first, public then private, lexically within them
 	// ipv4 first, public then private, lexically within them
@@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) {
 	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
 	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
 
 
 	// Ensure we can hoist a specific ipv4 range over anything else
 	// Ensure we can hoist a specific ipv4 range over anything else
-	_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
-	assert.NoError(t, err)
-	rl.Rebuild([]*net.IPNet{ipNet})
+	rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")})
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
 
 
 	// Preferred ipv4 first
 	// Preferred ipv4 first
@@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) {
 func BenchmarkFullRebuild(b *testing.B) {
 func BenchmarkFullRebuild(b *testing.B) {
 	rl := NewRemoteList(nil)
 	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+			newIp4AndPortFromString("70.199.182.92:1475"),
+			newIp4AndPortFromString("172.17.0.182:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),
+			newIp4AndPortFromString("172.18.0.1:10101"),
+			newIp4AndPortFromString("172.19.0.1:10101"),
+			newIp4AndPortFromString("172.31.0.1:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip6AndPort{
 		[]*Ip6AndPort{
-			NewIp6AndPort(net.ParseIP("1::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
-			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			newIp6AndPortFromString("[1::1]:1"),
+			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+			newIp6AndPortFromString("[1:100::1]:1"),
+			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
 		},
-		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *Ip6AndPort) bool { return true },
 	)
 	)
 
 
 	b.Run("no preferred", func(b *testing.B) {
 	b.Run("no preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{})
+			rl.Rebuild([]netip.Prefix{})
 		}
 		}
 	})
 	})
 
 
-	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
-	assert.NoError(b, err)
+	ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
 	b.Run("1 preferred", func(b *testing.B) {
 	b.Run("1 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{ipNet})
+			rl.Rebuild([]netip.Prefix{ipNet1})
 		}
 		}
 	})
 	})
 
 
-	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
-	assert.NoError(b, err)
+	ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
 	b.Run("2 preferred", func(b *testing.B) {
 	b.Run("2 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+			rl.Rebuild([]netip.Prefix{ipNet2})
 		}
 		}
 	})
 	})
 
 
-	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
-	assert.NoError(b, err)
+	ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
 	b.Run("3 preferred", func(b *testing.B) {
 	b.Run("3 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+			rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
 		}
 		}
 	})
 	})
 }
 }
@@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) {
 func BenchmarkSortRebuild(b *testing.B) {
 func BenchmarkSortRebuild(b *testing.B) {
 	rl := NewRemoteList(nil)
 	rl := NewRemoteList(nil)
 	rl.unlockedSetV4(
 	rl.unlockedSetV4(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip4AndPort{
 		[]*Ip4AndPort{
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101},
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101},   // this is a dupe
-			{Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port
+			newIp4AndPortFromString("70.199.182.92:1475"),
+			newIp4AndPortFromString("172.17.0.182:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),
+			newIp4AndPortFromString("172.18.0.1:10101"),
+			newIp4AndPortFromString("172.19.0.1:10101"),
+			newIp4AndPortFromString("172.31.0.1:10101"),
+			newIp4AndPortFromString("172.17.1.1:10101"),   // this is a dupe
+			newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port
 		},
 		},
-		func(iputil.VpnIp, *Ip4AndPort) bool { return true },
+		func(netip.Addr, *Ip4AndPort) bool { return true },
 	)
 	)
 
 
 	rl.unlockedSetV6(
 	rl.unlockedSetV6(
-		0,
-		0,
+		netip.MustParseAddr("0.0.0.0"),
+		netip.MustParseAddr("0.0.0.0"),
 		[]*Ip6AndPort{
 		[]*Ip6AndPort{
-			NewIp6AndPort(net.ParseIP("1::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
-			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
-			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			newIp6AndPortFromString("[1::1]:1"),
+			newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port
+			newIp6AndPortFromString("[1:100::1]:1"),
+			newIp6AndPortFromString("[1::1]:1"), // this is a dupe
 		},
 		},
-		func(iputil.VpnIp, *Ip6AndPort) bool { return true },
+		func(netip.Addr, *Ip6AndPort) bool { return true },
 	)
 	)
 
 
 	b.Run("no preferred", func(b *testing.B) {
 	b.Run("no preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
 			rl.shouldRebuild = true
 			rl.shouldRebuild = true
-			rl.Rebuild([]*net.IPNet{})
+			rl.Rebuild([]netip.Prefix{})
 		}
 		}
 	})
 	})
 
 
-	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
-	rl.Rebuild([]*net.IPNet{ipNet})
+	ipNet1 := netip.MustParsePrefix("172.17.0.0/16")
+	rl.Rebuild([]netip.Prefix{ipNet1})
 
 
-	assert.NoError(b, err)
 	b.Run("1 preferred", func(b *testing.B) {
 	b.Run("1 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
-			rl.Rebuild([]*net.IPNet{ipNet})
+			rl.Rebuild([]netip.Prefix{ipNet1})
 		}
 		}
 	})
 	})
 
 
-	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
-	rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+	ipNet2 := netip.MustParsePrefix("70.0.0.0/8")
+	rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
 
 
-	assert.NoError(b, err)
 	b.Run("2 preferred", func(b *testing.B) {
 	b.Run("2 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+			rl.Rebuild([]netip.Prefix{ipNet1, ipNet2})
 		}
 		}
 	})
 	})
 
 
-	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
-	rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+	ipNet3 := netip.MustParsePrefix("0.0.0.0/0")
+	rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
 
 
-	assert.NoError(b, err)
 	b.Run("3 preferred", func(b *testing.B) {
 	b.Run("3 preferred", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 		for i := 0; i < b.N; i++ {
-			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+			rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3})
 		}
 		}
 	})
 	})
 }
 }
+
+func newIp4AndPortFromString(s string) *Ip4AndPort {
+	a := netip.MustParseAddrPort(s)
+	v4Addr := a.Addr().As4()
+	return &Ip4AndPort{
+		Ip:   binary.BigEndian.Uint32(v4Addr[:]),
+		Port: uint32(a.Port()),
+	}
+}
+
+func newIp6AndPortFromString(s string) *Ip6AndPort {
+	a := netip.MustParseAddrPort(s)
+	v6Addr := a.Addr().As16()
+	return &Ip6AndPort{
+		Hi:   binary.BigEndian.Uint64(v6Addr[:8]),
+		Lo:   binary.BigEndian.Uint64(v6Addr[8:]),
+		Port: uint32(a.Port()),
+	}
+}

+ 1 - 1
service/service.go

@@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) {
 
 
 	ipNet := device.Cidr()
 	ipNet := device.Cidr()
 	pa := tcpip.ProtocolAddress{
 	pa := tcpip.ProtocolAddress{
-		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
+		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(),
 		Protocol:          ipv4.ProtocolNumber,
 		Protocol:          ipv4.ProtocolNumber,
 	}
 	}
 	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
 	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{

+ 6 - 10
service/service_test.go

@@ -4,7 +4,7 @@ import (
 	"bytes"
 	"bytes"
 	"context"
 	"context"
 	"errors"
 	"errors"
-	"net"
+	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -18,12 +18,8 @@ import (
 
 
 type m map[string]interface{}
 type m map[string]interface{}
 
 
-func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service {
-
-	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
-	copy(vpnIpNet.IP, udpIp)
-
-	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
+func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service {
+	_, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{})
 	caB, err := caCrt.MarshalToPEM()
 	caB, err := caCrt.MarshalToPEM()
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
@@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string,
 }
 }
 
 
 func TestService(t *testing.T) {
 func TestService(t *testing.T) {
-	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{
+	ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{
 		"static_host_map": m{},
 		"static_host_map": m{},
 		"lighthouse": m{
 		"lighthouse": m{
 			"am_lighthouse": true,
 			"am_lighthouse": true,
@@ -94,7 +90,7 @@ func TestService(t *testing.T) {
 			"port": 4243,
 			"port": 4243,
 		},
 		},
 	})
 	})
-	b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{
+	b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{
 		"static_host_map": m{
 		"static_host_map": m{
 			"10.0.0.1": []string{"localhost:4243"},
 			"10.0.0.1": []string{"localhost:4243"},
 		},
 		},

+ 29 - 36
ssh.go

@@ -7,6 +7,7 @@ import (
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"os"
 	"os"
 	"reflect"
 	"reflect"
 	"runtime"
 	"runtime"
@@ -18,9 +19,7 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/sshd"
 	"github.com/slackhq/nebula/sshd"
-	"github.com/slackhq/nebula/udp"
 )
 )
 
 
 type sshListHostMapFlags struct {
 type sshListHostMapFlags struct {
@@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er
 	}
 	}
 
 
 	sort.Slice(hm, func(i, j int) bool {
 	sort.Slice(hm, func(i, j int) bool {
-		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
+		return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0
 	})
 	})
 
 
 	if fs.Json || fs.Pretty {
 	if fs.Json || fs.Pretty {
@@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine("No vpn ip was provided")
 		return w.WriteLine("No vpn ip was provided")
 	}
 	}
 
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
@@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine("No vpn ip was provided")
 		return w.WriteLine("No vpn ip was provided")
 	}
 	}
 
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
@@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("No vpn ip was provided")
 		return w.WriteLine("No vpn ip was provided")
 	}
 	}
 
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
@@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
 	}
 
 
-	var addr *udp.Addr
+	var addr netip.AddrPort
 	if flags.Address != "" {
 	if flags.Address != "" {
-		addr = udp.NewAddrFromString(flags.Address)
-		if addr == nil {
+		addr, err = netip.ParseAddrPort(flags.Address)
+		if err != nil {
 			return w.WriteLine("Address could not be parsed")
 			return w.WriteLine("Address could not be parsed")
 		}
 		}
 	}
 	}
 
 
 	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
 	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
-	if addr != nil {
+	if addr.IsValid() {
 		hostInfo.SetRemote(addr)
 		hostInfo.SetRemote(addr)
 	}
 	}
 
 
@@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine("No address was provided")
 		return w.WriteLine("No address was provided")
 	}
 	}
 
 
-	addr := udp.NewAddrFromString(flags.Address)
-	if addr == nil {
+	addr, err := netip.ParseAddrPort(flags.Address)
+	if err != nil {
 		return w.WriteLine("Address could not be parsed")
 		return w.WriteLine("Address could not be parsed")
 	}
 	}
 
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
@@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 
 
 	cert := ifce.pki.GetCertState().Certificate
 	cert := ifce.pki.GetCertState().Certificate
 	if len(a) > 0 {
 	if len(a) > 0 {
-		parsedIp := net.ParseIP(a[0])
-		if parsedIp == nil {
+		vpnIp, err := netip.ParseAddr(a[0])
+		if err != nil {
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 		}
 
 
-		vpnIp := iputil.Ip2VpnIp(parsedIp)
-		if vpnIp == 0 {
+		if !vpnIp.IsValid() {
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 		}
 
 
@@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		Error          error
 		Error          error
 		Type           string
 		Type           string
 		State          string
 		State          string
-		PeerIp         iputil.VpnIp
+		PeerIp         netip.Addr
 		LocalIndex     uint32
 		LocalIndex     uint32
 		RemoteIndex    uint32
 		RemoteIndex    uint32
-		RelayedThrough []iputil.VpnIp
+		RelayedThrough []netip.Addr
 	}
 	}
 
 
 	type RelayOutput struct {
 	type RelayOutput struct {
-		NebulaIp    iputil.VpnIp
+		NebulaIp    netip.Addr
 		RelayForIps []RelayFor
 		RelayForIps []RelayFor
 	}
 	}
 
 
@@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine("No vpn ip was provided")
 		return w.WriteLine("No vpn ip was provided")
 	}
 	}
 
 
-	parsedIp := net.ParseIP(a[0])
-	if parsedIp == nil {
+	vpnIp, err := netip.ParseAddr(a[0])
+	if err != nil {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 
-	vpnIp := iputil.Ip2VpnIp(parsedIp)
-	if vpnIp == 0 {
+	if !vpnIp.IsValid() {
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 	}
 
 

+ 5 - 7
test/tun.go

@@ -3,23 +3,21 @@ package test
 import (
 import (
 	"errors"
 	"errors"
 	"io"
 	"io"
-	"net"
-
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 )
 
 
 type NoopTun struct{}
 type NoopTun struct{}
 
 
-func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp {
-	return 0
+func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
+	return netip.Addr{}
 }
 }
 
 
 func (NoopTun) Activate() error {
 func (NoopTun) Activate() error {
 	return nil
 	return nil
 }
 }
 
 
-func (NoopTun) Cidr() *net.IPNet {
-	return nil
+func (NoopTun) Cidr() netip.Prefix {
+	return netip.Prefix{}
 }
 }
 
 
 func (NoopTun) Name() string {
 func (NoopTun) Name() string {

+ 5 - 4
timeout_test.go

@@ -1,6 +1,7 @@
 package nebula
 package nebula
 
 
 import (
 import (
+	"net/netip"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) {
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)
 
 
 	fps := []firewall.Packet{
 	fps := []firewall.Packet{
-		{LocalIP: 1},
-		{LocalIP: 2},
-		{LocalIP: 3},
-		{LocalIP: 4},
+		{LocalIP: netip.MustParseAddr("0.0.0.1")},
+		{LocalIP: netip.MustParseAddr("0.0.0.2")},
+		{LocalIP: netip.MustParseAddr("0.0.0.3")},
+		{LocalIP: netip.MustParseAddr("0.0.0.4")},
 	}
 	}
 
 
 	tw.Add(fps[0], time.Second*1)
 	tw.Add(fps[0], time.Second*1)

+ 8 - 6
udp/conn.go

@@ -1,6 +1,8 @@
 package udp
 package udp
 
 
 import (
 import (
+	"net/netip"
+
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
@@ -9,7 +11,7 @@ import (
 const MTU = 9001
 const MTU = 9001
 
 
 type EncReader func(
 type EncReader func(
-	addr *Addr,
+	addr netip.AddrPort,
 	out []byte,
 	out []byte,
 	packet []byte,
 	packet []byte,
 	header *header.H,
 	header *header.H,
@@ -22,9 +24,9 @@ type EncReader func(
 
 
 type Conn interface {
 type Conn interface {
 	Rebind() error
 	Rebind() error
-	LocalAddr() (*Addr, error)
+	LocalAddr() (netip.AddrPort, error)
 	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
 	ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int)
-	WriteTo(b []byte, addr *Addr) error
+	WriteTo(b []byte, addr netip.AddrPort) error
 	ReloadConfig(c *config.C)
 	ReloadConfig(c *config.C)
 	Close() error
 	Close() error
 }
 }
@@ -34,13 +36,13 @@ type NoopConn struct{}
 func (NoopConn) Rebind() error {
 func (NoopConn) Rebind() error {
 	return nil
 	return nil
 }
 }
-func (NoopConn) LocalAddr() (*Addr, error) {
-	return nil, nil
+func (NoopConn) LocalAddr() (netip.AddrPort, error) {
+	return netip.AddrPort{}, nil
 }
 }
 func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
 func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) {
 	return
 	return
 }
 }
-func (NoopConn) WriteTo(_ []byte, _ *Addr) error {
+func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
 	return nil
 	return nil
 }
 }
 func (NoopConn) ReloadConfig(_ *config.C) {
 func (NoopConn) ReloadConfig(_ *config.C) {

+ 3 - 2
udp/temp.go

@@ -1,9 +1,10 @@
 package udp
 package udp
 
 
 import (
 import (
-	"github.com/slackhq/nebula/iputil"
+	"net/netip"
 )
 )
 
 
 //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
 //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare
 
 
-type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte)
+// TODO: IPV6-WORK this can likely be removed now
+type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte)

+ 0 - 100
udp/udp_all.go

@@ -1,100 +0,0 @@
-package udp
-
-import (
-	"encoding/json"
-	"fmt"
-	"net"
-	"strconv"
-)
-
-type m map[string]interface{}
-
-type Addr struct {
-	IP   net.IP
-	Port uint16
-}
-
-func NewAddr(ip net.IP, port uint16) *Addr {
-	addr := Addr{IP: make([]byte, net.IPv6len), Port: port}
-	copy(addr.IP, ip.To16())
-	return &addr
-}
-
-func NewAddrFromString(s string) *Addr {
-	ip, port, err := ParseIPAndPort(s)
-	//TODO: handle err
-	_ = err
-	return &Addr{IP: ip.To16(), Port: port}
-}
-
-func (ua *Addr) Equals(t *Addr) bool {
-	if t == nil || ua == nil {
-		return t == nil && ua == nil
-	}
-	return ua.IP.Equal(t.IP) && ua.Port == t.Port
-}
-
-func (ua *Addr) String() string {
-	if ua == nil {
-		return "<nil>"
-	}
-
-	return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port))
-}
-
-func (ua *Addr) MarshalJSON() ([]byte, error) {
-	if ua == nil {
-		return nil, nil
-	}
-
-	return json.Marshal(m{"ip": ua.IP, "port": ua.Port})
-}
-
-func (ua *Addr) Copy() *Addr {
-	if ua == nil {
-		return nil
-	}
-
-	nu := Addr{
-		Port: ua.Port,
-		IP:   make(net.IP, len(ua.IP)),
-	}
-
-	copy(nu.IP, ua.IP)
-	return &nu
-}
-
-type AddrSlice []*Addr
-
-func (a AddrSlice) Equal(b AddrSlice) bool {
-	if len(a) != len(b) {
-		return false
-	}
-
-	for i := range a {
-		if !a[i].Equals(b[i]) {
-			return false
-		}
-	}
-
-	return true
-}
-
-func ParseIPAndPort(s string) (net.IP, uint16, error) {
-	rIp, sPort, err := net.SplitHostPort(s)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	addr, err := net.ResolveIPAddr("ip", rIp)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	iPort, err := strconv.Atoi(sPort)
-	if err != nil {
-		return nil, 0, err
-	}
-
-	return addr.IP, uint16(iPort), nil
-}

+ 2 - 1
udp/udp_android.go

@@ -6,13 +6,14 @@ package udp
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 }
 
 

+ 2 - 1
udp/udp_bsd.go

@@ -9,13 +9,14 @@ package udp
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 }
 
 

+ 2 - 1
udp/udp_darwin.go

@@ -8,13 +8,14 @@ package udp
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 }
 
 

+ 23 - 14
udp/udp_generic.go

@@ -11,6 +11,7 @@ import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/config"
@@ -25,7 +26,7 @@ type GenericConn struct {
 
 
 var _ Conn = &GenericConn{}
 var _ Conn = &GenericConn{}
 
 
-func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	lc := NewListenConfig(multi)
 	lc := NewListenConfig(multi)
 	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
 	pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)))
 	if err != nil {
 	if err != nil {
@@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 	return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc)
 }
 }
 
 
-func (u *GenericConn) WriteTo(b []byte, addr *Addr) error {
-	_, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)})
+func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
+	_, err := u.UDPConn.WriteToUDPAddrPort(b, addr)
 	return err
 	return err
 }
 }
 
 
-func (u *GenericConn) LocalAddr() (*Addr, error) {
+func (u *GenericConn) LocalAddr() (netip.AddrPort, error) {
 	a := u.UDPConn.LocalAddr()
 	a := u.UDPConn.LocalAddr()
 
 
 	switch v := a.(type) {
 	switch v := a.(type) {
 	case *net.UDPAddr:
 	case *net.UDPAddr:
-		addr := &Addr{IP: make([]byte, len(v.IP))}
-		copy(addr.IP, v.IP)
-		addr.Port = uint16(v.Port)
-		return addr, nil
+		addr, ok := netip.AddrFromSlice(v.IP)
+		if !ok {
+			return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP)
+		}
+		return netip.AddrPortFrom(addr, uint16(v.Port)), nil
 
 
 	default:
 	default:
-		return nil, fmt.Errorf("LocalAddr returned: %#v", a)
+		return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a)
 	}
 	}
 }
 }
 
 
@@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f
 	buffer := make([]byte, MTU)
 	buffer := make([]byte, MTU)
 	h := &header.H{}
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	fwPacket := &firewall.Packet{}
-	udpAddr := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
 	for {
 	for {
 		// Just read one packet at a time
 		// Just read one packet at a time
-		n, rua, err := u.ReadFromUDP(buffer)
+		n, rua, err := u.ReadFromUDPAddrPort(buffer)
 		if err != nil {
 		if err != nil {
 			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
 			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
 			return
 			return
 		}
 		}
 
 
-		udpAddr.IP = rua.IP
-		udpAddr.Port = uint16(rua.Port)
-		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(
+			netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()),
+			plaintext[:0],
+			buffer[:n],
+			h,
+			fwPacket,
+			lhf,
+			nb,
+			q,
+			cache.Get(u.l),
+		)
 	}
 	}
 }
 }

+ 43 - 32
udp/udp_linux.go

@@ -7,6 +7,7 @@ import (
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"syscall"
 	"syscall"
 	"unsafe"
 	"unsafe"
 
 
@@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) {
 	return ip, false
 	return ip, false
 }
 }
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
-	ipV4, isV4 := maybeIPV4(ip)
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	af := unix.AF_INET6
 	af := unix.AF_INET6
-	if isV4 {
+	if ip.Is4() {
 		af = unix.AF_INET
 		af = unix.AF_INET
 	}
 	}
 	syscall.ForkLock.RLock()
 	syscall.ForkLock.RLock()
@@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
 
 
 	//TODO: support multiple listening IPs (for limiting ipv6)
 	//TODO: support multiple listening IPs (for limiting ipv6)
 	var sa unix.Sockaddr
 	var sa unix.Sockaddr
-	if isV4 {
+	if ip.Is4() {
 		sa4 := &unix.SockaddrInet4{Port: port}
 		sa4 := &unix.SockaddrInet4{Port: port}
-		copy(sa4.Addr[:], ipV4)
+		sa4.Addr = ip.As4()
 		sa = sa4
 		sa = sa4
 	} else {
 	} else {
 		sa6 := &unix.SockaddrInet6{Port: port}
 		sa6 := &unix.SockaddrInet6{Port: port}
-		copy(sa6.Addr[:], ip.To16())
+		sa6.Addr = ip.As16()
 		sa = sa6
 		sa = sa6
 	}
 	}
 	if err = unix.Bind(fd, sa); err != nil {
 	if err = unix.Bind(fd, sa); err != nil {
@@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
 	//l.Println(v, err)
 	//l.Println(v, err)
 
 
-	return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
+	return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
 }
 }
 
 
 func (u *StdConn) Rebind() error {
 func (u *StdConn) Rebind() error {
@@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) {
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
 }
 }
 
 
-func (u *StdConn) LocalAddr() (*Addr, error) {
+func (u *StdConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := unix.Getsockname(u.sysFd)
 	sa, err := unix.Getsockname(u.sysFd)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return netip.AddrPort{}, err
 	}
 	}
 
 
-	addr := &Addr{}
 	switch sa := sa.(type) {
 	switch sa := sa.(type) {
 	case *unix.SockaddrInet4:
 	case *unix.SockaddrInet4:
-		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
-		addr.Port = uint16(sa.Port)
+		return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil
+
 	case *unix.SockaddrInet6:
 	case *unix.SockaddrInet6:
-		addr.IP = sa.Addr[0:]
-		addr.Port = uint16(sa.Port)
-	}
+		return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil
 
 
-	return addr, nil
+	default:
+		return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa)
+	}
 }
 }
 
 
 func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
 func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
 	plaintext := make([]byte, MTU)
 	plaintext := make([]byte, MTU)
 	h := &header.H{}
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	fwPacket := &firewall.Packet{}
-	udpAddr := &Addr{}
+	var ip netip.Addr
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
 	//TODO: should we track this?
 	//TODO: should we track this?
@@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 		//metric.Update(int64(n))
 		//metric.Update(int64(n))
 		for i := 0; i < n; i++ {
 		for i := 0; i < n; i++ {
 			if u.isV4 {
 			if u.isV4 {
-				udpAddr.IP = names[i][4:8]
+				ip, _ = netip.AddrFromSlice(names[i][4:8])
+				//TODO: IPV6-WORK what is not ok?
 			} else {
 			} else {
-				udpAddr.IP = names[i][8:24]
+				ip, _ = netip.AddrFromSlice(names[i][8:24])
+				//TODO: IPV6-WORK what is not ok?
 			}
 			}
-			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
-			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+			r(
+				netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])),
+				plaintext[:0],
+				buffers[i][:msgs[i].Len],
+				h,
+				fwPacket,
+				lhf,
+				nb,
+				q,
+				cache.Get(u.l),
+			)
 		}
 		}
 	}
 	}
 }
 }
@@ -197,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
 	}
 	}
 }
 }
 
 
-func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
+func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
 	if u.isV4 {
 	if u.isV4 {
-		return u.writeTo4(b, addr)
+		return u.writeTo4(b, ip)
 	}
 	}
-	return u.writeTo6(b, addr)
+	return u.writeTo6(b, ip)
 }
 }
 
 
-func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
+func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error {
 	var rsa unix.RawSockaddrInet6
 	var rsa unix.RawSockaddrInet6
 	rsa.Family = unix.AF_INET6
 	rsa.Family = unix.AF_INET6
+	rsa.Addr = ip.Addr().As16()
+	port := ip.Port()
 	// Little Endian -> Network Endian
 	// Little Endian -> Network Endian
-	rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
-	copy(rsa.Addr[:], addr.IP.To16())
+	rsa.Port = (port >> 8) | ((port & 0xff) << 8)
 
 
 	for {
 	for {
 		_, _, err := unix.Syscall6(
 		_, _, err := unix.Syscall6(
@@ -232,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
 	}
 	}
 }
 }
 
 
-func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
-	addrV4, isAddrV4 := maybeIPV4(addr.IP)
-	if !isAddrV4 {
+func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error {
+	if !ip.Addr().Is4() {
 		return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
 		return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
 	}
 	}
 
 
 	var rsa unix.RawSockaddrInet4
 	var rsa unix.RawSockaddrInet4
 	rsa.Family = unix.AF_INET
 	rsa.Family = unix.AF_INET
+	rsa.Addr = ip.Addr().As4()
+	port := ip.Port()
 	// Little Endian -> Network Endian
 	// Little Endian -> Network Endian
-	rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
-	copy(rsa.Addr[:], addrV4)
+	rsa.Port = (port >> 8) | ((port & 0xff) << 8)
 
 
 	for {
 	for {
 		_, _, err := unix.Syscall6(
 		_, _, err := unix.Syscall6(

+ 2 - 1
udp/udp_netbsd.go

@@ -8,13 +8,14 @@ package udp
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"golang.org/x/sys/unix"
 	"golang.org/x/sys/unix"
 )
 )
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	return NewGenericListener(l, ip, port, multi, batch)
 	return NewGenericListener(l, ip, port, multi, batch)
 }
 }
 
 

+ 22 - 21
udp/udp_rio_windows.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net"
 	"net"
+	"net/netip"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"syscall"
 	"syscall"
@@ -61,16 +62,14 @@ type RIOConn struct {
 	results [packetsPerRing]winrio.Result
 	results [packetsPerRing]winrio.Result
 }
 }
 
 
-func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) {
+func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) {
 	if !winrio.Initialize() {
 	if !winrio.Initialize() {
 		return nil, errors.New("could not initialize winrio")
 		return nil, errors.New("could not initialize winrio")
 	}
 	}
 
 
 	u := &RIOConn{l: l}
 	u := &RIOConn{l: l}
 
 
-	addr := [16]byte{}
-	copy(addr[:], ip.To16())
-	err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port})
+	err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port})
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("bind: %w", err)
 		return nil, fmt.Errorf("bind: %w", err)
 	}
 	}
@@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 	buffer := make([]byte, MTU)
 	buffer := make([]byte, MTU)
 	h := &header.H{}
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	fwPacket := &firewall.Packet{}
-	udpAddr := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
 	for {
 	for {
@@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew
 			return
 			return
 		}
 		}
 
 
-		udpAddr.IP = rua.Addr[:]
-		p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port))
-		p[0] = byte(rua.Port >> 8)
-		p[1] = byte(rua.Port)
-		r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(
+			netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)),
+			plaintext[:0],
+			buffer[:n],
+			h,
+			fwPacket,
+			lhf,
+			nb,
+			q,
+			cache.Get(u.l),
+		)
 	}
 	}
 }
 }
 
 
@@ -231,7 +235,7 @@ retry:
 	return n, ep, nil
 	return n, ep, nil
 }
 }
 
 
-func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
+func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
 	if !u.isOpen.Load() {
 	if !u.isOpen.Load() {
 		return net.ErrClosed
 		return net.ErrClosed
 	}
 	}
@@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
 
 
 	packet := u.tx.Push()
 	packet := u.tx.Push()
 	packet.addr.Family = windows.AF_INET6
 	packet.addr.Family = windows.AF_INET6
-	p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port))
-	p[0] = byte(addr.Port >> 8)
-	p[1] = byte(addr.Port)
-	copy(packet.addr.Addr[:], addr.IP.To16())
+	packet.addr.Addr = ip.Addr().As16()
+	port := ip.Port()
+	packet.addr.Port = (port >> 8) | ((port & 0xff) << 8)
 	copy(packet.data[:], buf)
 	copy(packet.data[:], buf)
 
 
 	dataBuffer := &winrio.Buffer{
 	dataBuffer := &winrio.Buffer{
@@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error {
 	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
 	return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
 }
 }
 
 
-func (u *RIOConn) LocalAddr() (*Addr, error) {
+func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
 	sa, err := windows.Getsockname(u.sock)
 	sa, err := windows.Getsockname(u.sock)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return netip.AddrPort{}, err
 	}
 	}
 
 
 	v6 := sa.(*windows.SockaddrInet6)
 	v6 := sa.(*windows.SockaddrInet6)
-	return &Addr{
-		IP:   v6.Addr[:],
-		Port: uint16(v6.Port),
-	}, nil
+	return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil
+
 }
 }
 
 
 func (u *RIOConn) Rebind() error {
 func (u *RIOConn) Rebind() error {

+ 17 - 32
udp/udp_tester.go

@@ -4,9 +4,8 @@
 package udp
 package udp
 
 
 import (
 import (
-	"fmt"
 	"io"
 	"io"
-	"net"
+	"net/netip"
 	"sync/atomic"
 	"sync/atomic"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
@@ -16,30 +15,24 @@ import (
 )
 )
 
 
 type Packet struct {
 type Packet struct {
-	ToIp     net.IP
-	ToPort   uint16
-	FromIp   net.IP
-	FromPort uint16
-	Data     []byte
+	To   netip.AddrPort
+	From netip.AddrPort
+	Data []byte
 }
 }
 
 
 func (u *Packet) Copy() *Packet {
 func (u *Packet) Copy() *Packet {
 	n := &Packet{
 	n := &Packet{
-		ToIp:     make(net.IP, len(u.ToIp)),
-		ToPort:   u.ToPort,
-		FromIp:   make(net.IP, len(u.FromIp)),
-		FromPort: u.FromPort,
-		Data:     make([]byte, len(u.Data)),
+		To:   u.To,
+		From: u.From,
+		Data: make([]byte, len(u.Data)),
 	}
 	}
 
 
-	copy(n.ToIp, u.ToIp)
-	copy(n.FromIp, u.FromIp)
 	copy(n.Data, u.Data)
 	copy(n.Data, u.Data)
 	return n
 	return n
 }
 }
 
 
 type TesterConn struct {
 type TesterConn struct {
-	Addr *Addr
+	Addr netip.AddrPort
 
 
 	RxPackets chan *Packet // Packets to receive into nebula
 	RxPackets chan *Packet // Packets to receive into nebula
 	TxPackets chan *Packet // Packets transmitted outside by nebula
 	TxPackets chan *Packet // Packets transmitted outside by nebula
@@ -48,9 +41,9 @@ type TesterConn struct {
 	l      *logrus.Logger
 	l      *logrus.Logger
 }
 }
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) {
 	return &TesterConn{
 	return &TesterConn{
-		Addr:      &Addr{ip, uint16(port)},
+		Addr:      netip.AddrPortFrom(ip, uint16(port)),
 		RxPackets: make(chan *Packet, 10),
 		RxPackets: make(chan *Packet, 10),
 		TxPackets: make(chan *Packet, 10),
 		TxPackets: make(chan *Packet, 10),
 		l:         l,
 		l:         l,
@@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) {
 	}
 	}
 	if u.l.Level >= logrus.DebugLevel {
 	if u.l.Level >= logrus.DebugLevel {
 		u.l.WithField("header", h).
 		u.l.WithField("header", h).
-			WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+			WithField("udpAddr", packet.From).
 			WithField("dataLen", len(packet.Data)).
 			WithField("dataLen", len(packet.Data)).
 			Debug("UDP receiving injected packet")
 			Debug("UDP receiving injected packet")
 	}
 	}
@@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet {
 // Below this is boilerplate implementation to make nebula actually work
 // Below this is boilerplate implementation to make nebula actually work
 //********************************************************************************************************************//
 //********************************************************************************************************************//
 
 
-func (u *TesterConn) WriteTo(b []byte, addr *Addr) error {
+func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
 	if u.closed.Load() {
 	if u.closed.Load() {
 		return io.ErrClosedPipe
 		return io.ErrClosedPipe
 	}
 	}
 
 
 	p := &Packet{
 	p := &Packet{
-		Data:     make([]byte, len(b), len(b)),
-		FromIp:   make([]byte, 16),
-		FromPort: u.Addr.Port,
-		ToIp:     make([]byte, 16),
-		ToPort:   addr.Port,
+		Data: make([]byte, len(b), len(b)),
+		From: u.Addr,
+		To:   addr,
 	}
 	}
 
 
 	copy(p.Data, b)
 	copy(p.Data, b)
-	copy(p.ToIp, addr.IP.To16())
-	copy(p.FromIp, u.Addr.IP.To16())
-
 	u.TxPackets <- p
 	u.TxPackets <- p
 	return nil
 	return nil
 }
 }
@@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
 	plaintext := make([]byte, MTU)
 	plaintext := make([]byte, MTU)
 	h := &header.H{}
 	h := &header.H{}
 	fwPacket := &firewall.Packet{}
 	fwPacket := &firewall.Packet{}
-	ua := &Addr{IP: make([]byte, 16)}
 	nb := make([]byte, 12, 12)
 	nb := make([]byte, 12, 12)
 
 
 	for {
 	for {
@@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi
 		if !ok {
 		if !ok {
 			return
 			return
 		}
 		}
-		ua.Port = p.FromPort
-		copy(ua.IP, p.FromIp.To16())
-		r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
+		r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l))
 	}
 	}
 }
 }
 
 
@@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() {
 	return func() {}
 	return func() {}
 }
 }
 
 
-func (u *TesterConn) LocalAddr() (*Addr, error) {
+func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
 	return u.Addr, nil
 	return u.Addr, nil
 }
 }
 
 

+ 2 - 1
udp/udp_windows.go

@@ -6,12 +6,13 @@ package udp
 import (
 import (
 	"fmt"
 	"fmt"
 	"net"
 	"net"
+	"net/netip"
 	"syscall"
 	"syscall"
 
 
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
 
 
-func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
+func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) {
 	if multi {
 	if multi {
 		//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
 		//NOTE: Technically we can support it with RIO but it wouldn't be at the socket level
 		// The udp stack would need to be reworked to hide away the implementation differences between
 		// The udp stack would need to be reworked to hide away the implementation differences between