瀏覽代碼

Use generics for CIDRTrees to avoid casting issues (#1004)

Nate Brown 1 年之前
父節點
當前提交
5181cb0474

+ 14 - 29
allow_list.go

@@ -12,7 +12,7 @@ import (
 
 type AllowList struct {
 	// The values of this cidrTree are `bool`, signifying allow/deny
-	cidrTree *cidr.Tree6
+	cidrTree *cidr.Tree6[bool]
 }
 
 type RemoteAllowList struct {
@@ -20,7 +20,7 @@ type RemoteAllowList struct {
 
 	// Inside Range Specific, keys of this tree are inside CIDRs and values
 	// are *AllowList
-	insideAllowLists *cidr.Tree6
+	insideAllowLists *cidr.Tree6[*AllowList]
 }
 
 type LocalAllowList struct {
@@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
 		return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
 	}
 
-	tree := cidr.NewTree6()
+	tree := cidr.NewTree6[bool]()
 
 	// Keep track of the rules we have added for both ipv4 and ipv6
 	type allowListRules struct {
@@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
 	return nameRules, nil
 }
 
-func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
+func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	remoteAllowRanges := cidr.NewTree6()
+	remoteAllowRanges := cidr.NewTree6[*AllowList]()
 
 	rawMap, ok := value.(map[interface{}]interface{})
 	if !ok {
@@ -257,13 +257,8 @@ func (al *AllowList) Allow(ip net.IP) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContains(ip)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContains(ip)
+	return result
 }
 
 func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
@@ -271,13 +266,8 @@ func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContainsIpV4(ip)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
+	return result
 }
 
 func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
@@ -285,13 +275,8 @@ func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
 		return true
 	}
 
-	result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
-	switch v := result.(type) {
-	case bool:
-		return v
-	default:
-		panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
-	}
+	_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
+	return result
 }
 
 func (al *LocalAllowList) Allow(ip net.IP) bool {
@@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {
 
 func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
 	if al.insideAllowLists != nil {
-		inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
-		if inside != nil {
-			return inside.(*AllowList)
+		ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
+		if ok {
+			return inside
 		}
 	}
 	return nil

+ 1 - 1
allow_list_test.go

@@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
 func TestAllowList_Allow(t *testing.T) {
 	assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))
 
-	tree := cidr.NewTree6()
+	tree := cidr.NewTree6[bool]()
 	tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
 	tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
 	tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)

+ 2 - 2
calculated_remote.go

@@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
 	return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
 }
 
-func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
+func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
 	value := c.Get(k)
 	if value == nil {
 		return nil, nil
 	}
 
-	calculatedRemotes := cidr.NewTree4()
+	calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()
 
 	rawMap, ok := value.(map[any]any)
 	if !ok {

+ 34 - 28
cidr/tree4.go

@@ -6,35 +6,36 @@ import (
 	"github.com/slackhq/nebula/iputil"
 )
 
-type Node struct {
-	left   *Node
-	right  *Node
-	parent *Node
-	value  interface{}
+type Node[T any] struct {
+	left     *Node[T]
+	right    *Node[T]
+	parent   *Node[T]
+	hasValue bool
+	value    T
 }
 
-type entry struct {
+type entry[T any] struct {
 	CIDR  *net.IPNet
-	Value *interface{}
+	Value T
 }
 
-type Tree4 struct {
-	root *Node
-	list []entry
+type Tree4[T any] struct {
+	root *Node[T]
+	list []entry[T]
 }
 
 const (
 	startbit = iputil.VpnIp(0x80000000)
 )
 
-func NewTree4() *Tree4 {
-	tree := new(Tree4)
-	tree.root = &Node{}
-	tree.list = []entry{}
+func NewTree4[T any]() *Tree4[T] {
+	tree := new(Tree4[T])
+	tree.root = &Node[T]{}
+	tree.list = []entry[T]{}
 	return tree
 }
 
-func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
+func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
 	bit := startbit
 	node := tree.root
 	next := tree.root
@@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 			}
 		}
 
-		tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
+		tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
 		node.value = val
+		node.hasValue = true
 		return
 	}
 
 	// Build up the rest of the tree we don't already have
 	for bit&mask != 0 {
-		next = &Node{}
+		next = &Node[T]{}
 		next.parent = node
 
 		if ip&bit != 0 {
@@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
-	tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
+	node.hasValue = true
+	tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
 }
 
 // Contains finds the first match, which may be the least specific
-func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 
 	for node != nil {
-		if node.value != nil {
-			return node.value
+		if node.hasValue {
+			return true, node.value
 		}
 
 		if ip&bit != 0 {
@@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
 
 	}
 
-	return value
+	return false, value
 }
 
 // MostSpecificContains finds the most specific match
-func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 
 	for node != nil {
-		if node.value != nil {
+		if node.hasValue {
 			value = node.value
+			ok = true
 		}
 
 		if ip&bit != 0 {
@@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
 		bit >>= 1
 	}
 
-	return value
+	return ok, value
 }
 
 // Match finds the most specific match
-func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
+// TODO this is exact match
+func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root
 	lastNode := node
@@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
 
 	if bit == 0 && lastNode != nil {
 		value = lastNode.value
+		ok = true
 	}
-	return value
+	return ok, value
 }
 
 // List will return all CIDRs and their current values. Do not modify the contents!
-func (tree *Tree4) List() []entry {
+func (tree *Tree4[T]) List() []entry[T] {
 	return tree.list
 }

+ 71 - 47
cidr/tree4_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestCIDRTree_List(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.0.0.0/8"), "2")
 	tree.AddCIDR(Parse("1.0.0.0/16"), "3")
@@ -17,13 +17,13 @@ func TestCIDRTree_List(t *testing.T) {
 	list := tree.List()
 	assert.Len(t, list, 2)
 	assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String())
-	assert.Equal(t, "2", *list[0].Value)
+	assert.Equal(t, "2", list[0].Value)
 	assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String())
-	assert.Equal(t, "4", *list[1].Value)
+	assert.Equal(t, "4", list[1].Value)
 }
 
 func TestCIDRTree_Contains(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -33,35 +33,43 @@ func TestCIDRTree_Contains(t *testing.T) {
 	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4a", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4a", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func TestCIDRTree_MostSpecificContains(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -71,59 +79,75 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
 	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4b", "4.1.1.2"},
+		{true, "4c", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func TestCIDRTree_Match(t *testing.T) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
 	tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1a", "4.1.1.0"},
-		{"1b", "4.1.1.1"},
+		{true, "1a", "4.1.1.0"},
+		{true, "1b", "4.1.1.1"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP))))
+		ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree4()
+	tree = NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))))
-	assert.Equal(t, "cool", tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))))
+	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
 }
 
 func BenchmarkCIDRTree_Contains(b *testing.B) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
 	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
@@ -145,7 +169,7 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
 }
 
 func BenchmarkCIDRTree_Match(b *testing.B) {
-	tree := NewTree4()
+	tree := NewTree4[string]()
 	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
 	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
 	tree.AddCIDR(Parse("192.2.1.1/32"), "1")

+ 24 - 20
cidr/tree6.go

@@ -8,20 +8,20 @@ import (
 
 const startbit6 = uint64(1 << 63)
 
-type Tree6 struct {
-	root4 *Node
-	root6 *Node
+type Tree6[T any] struct {
+	root4 *Node[T]
+	root6 *Node[T]
 }
 
-func NewTree6() *Tree6 {
-	tree := new(Tree6)
-	tree.root4 = &Node{}
-	tree.root6 = &Node{}
+func NewTree6[T any]() *Tree6[T] {
+	tree := new(Tree6[T])
+	tree.root4 = &Node[T]{}
+	tree.root6 = &Node[T]{}
 	return tree
 }
 
-func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
-	var node, next *Node
+func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) {
+	var node, next *Node[T]
 
 	cidrIP, ipv4 := isIPV4(cidr.IP)
 	if ipv4 {
@@ -56,7 +56,7 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 		// Build up the rest of the tree we don't already have
 		for bit&mask != 0 {
-			next = &Node{}
+			next = &Node[T]{}
 			next.parent = node
 
 			if ip&bit != 0 {
@@ -72,11 +72,12 @@ func (tree *Tree6) AddCIDR(cidr *net.IPNet, val interface{}) {
 
 	// Final node marks our cidr, set the value
 	node.value = val
+	node.hasValue = true
 }
 
 // Finds the most specific match
-func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
-	var node *Node
+func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) {
+	var node *Node[T]
 
 	wholeIP, ipv4 := isIPV4(ip)
 	if ipv4 {
@@ -90,8 +91,9 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
 		bit := startbit
 
 		for node != nil {
-			if node.value != nil {
+			if node.hasValue {
 				value = node.value
+				ok = true
 			}
 
 			if bit == 0 {
@@ -108,16 +110,17 @@ func (tree *Tree6) MostSpecificContains(ip net.IP) (value interface{}) {
 		}
 	}
 
-	return value
+	return ok, value
 }
 
-func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{}) {
+func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) {
 	bit := startbit
 	node := tree.root4
 
 	for node != nil {
-		if node.value != nil {
+		if node.hasValue {
 			value = node.value
+			ok = true
 		}
 
 		if ip&bit != 0 {
@@ -129,10 +132,10 @@ func (tree *Tree6) MostSpecificContainsIpV4(ip iputil.VpnIp) (value interface{})
 		bit >>= 1
 	}
 
-	return value
+	return ok, value
 }
 
-func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
+func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) {
 	ip := hi
 	node := tree.root6
 
@@ -140,8 +143,9 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 		bit := startbit6
 
 		for node != nil {
-			if node.value != nil {
+			if node.hasValue {
 				value = node.value
+				ok = true
 			}
 
 			if bit == 0 {
@@ -160,7 +164,7 @@ func (tree *Tree6) MostSpecificContainsIpV6(hi, lo uint64) (value interface{}) {
 		ip = lo
 	}
 
-	return value
+	return ok, value
 }
 
 func isIPV4(ip net.IP) (net.IP, bool) {

+ 45 - 28
cidr/tree6_test.go

@@ -9,7 +9,7 @@ import (
 )
 
 func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
-	tree := NewTree6()
+	tree := NewTree6[string]()
 	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
 	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
 	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
@@ -22,53 +22,68 @@ func TestCIDR6Tree_MostSpecificContains(t *testing.T) {
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"1", "1.0.0.0"},
-		{"1", "1.255.255.255"},
-		{"2", "2.1.0.0"},
-		{"2", "2.1.255.255"},
-		{"3", "3.1.1.0"},
-		{"3", "3.1.1.255"},
-		{"4a", "4.1.1.255"},
-		{"4b", "4.1.1.2"},
-		{"4c", "4.1.1.1"},
-		{"5", "240.0.0.0"},
-		{"5", "255.255.255.255"},
-		{"6a", "1:2:0:4:1:1:1:1"},
-		{"6b", "1:2:0:4:5:1:1:1"},
-		{"6c", "1:2:0:4:5:0:0:0"},
-		{nil, "239.0.0.0"},
-		{nil, "4.1.2.2"},
+		{true, "1", "1.0.0.0"},
+		{true, "1", "1.255.255.255"},
+		{true, "2", "2.1.0.0"},
+		{true, "2", "2.1.255.255"},
+		{true, "3", "3.1.1.0"},
+		{true, "3", "3.1.1.255"},
+		{true, "4a", "4.1.1.255"},
+		{true, "4b", "4.1.1.2"},
+		{true, "4c", "4.1.1.1"},
+		{true, "5", "240.0.0.0"},
+		{true, "5", "255.255.255.255"},
+		{true, "6a", "1:2:0:4:1:1:1:1"},
+		{true, "6b", "1:2:0:4:5:1:1:1"},
+		{true, "6c", "1:2:0:4:5:0:0:0"},
+		{false, "", "239.0.0.0"},
+		{false, "", "4.1.2.2"},
 	}
 
 	for _, tt := range tests {
-		assert.Equal(t, tt.Result, tree.MostSpecificContains(net.ParseIP(tt.IP)))
+		ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP))
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 
-	tree = NewTree6()
+	tree = NewTree6[string]()
 	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
 	tree.AddCIDR(Parse("::/0"), "cool6")
-	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("0.0.0.0")))
-	assert.Equal(t, "cool", tree.MostSpecificContains(net.ParseIP("255.255.255.255")))
-	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("::")))
-	assert.Equal(t, "cool6", tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")))
+	ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("::"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool6", r)
+
+	ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8"))
+	assert.True(t, ok)
+	assert.Equal(t, "cool6", r)
 }
 
 func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
-	tree := NewTree6()
+	tree := NewTree6[string]()
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a")
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b")
 	tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c")
 
 	tests := []struct {
+		Found  bool
 		Result interface{}
 		IP     string
 	}{
-		{"6a", "1:2:0:4:1:1:1:1"},
-		{"6b", "1:2:0:4:5:1:1:1"},
-		{"6c", "1:2:0:4:5:0:0:0"},
+		{true, "6a", "1:2:0:4:1:1:1:1"},
+		{true, "6b", "1:2:0:4:5:1:1:1"},
+		{true, "6c", "1:2:0:4:5:0:0:0"},
 	}
 
 	for _, tt := range tests {
@@ -76,6 +91,8 @@ func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) {
 		hi := binary.BigEndian.Uint64(ip[:8])
 		lo := binary.BigEndian.Uint64(ip[8:])
 
-		assert.Equal(t, tt.Result, tree.MostSpecificContainsIpV6(hi, lo))
+		ok, r := tree.MostSpecificContainsIpV6(hi, lo)
+		assert.Equal(t, tt.Found, ok)
+		assert.Equal(t, tt.Result, r)
 	}
 }

+ 22 - 14
firewall.go

@@ -57,7 +57,7 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps *cidr.Tree4
+	localIps *cidr.Tree4[struct{}]
 
 	rules        string
 	rulesVersion uint16
@@ -110,8 +110,8 @@ type FirewallRule struct {
 	Any       bool
 	Hosts     map[string]struct{}
 	Groups    [][]string
-	CIDR      *cidr.Tree4
-	LocalCIDR *cidr.Tree4
+	CIDR      *cidr.Tree4[struct{}]
+	LocalCIDR *cidr.Tree4[struct{}]
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
@@ -137,7 +137,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		max = defaultTimeout
 	}
 
-	localIps := cidr.NewTree4()
+	localIps := cidr.NewTree4[struct{}]()
 	for _, ip := range c.Details.Ips {
 		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}
@@ -391,7 +391,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
 
 	// Make sure remote address matches nebula certificate
 	if remoteCidr := h.remoteCidr; remoteCidr != nil {
-		if remoteCidr.Contains(fp.RemoteIP) == nil {
+		ok, _ := remoteCidr.Contains(fp.RemoteIP)
+		if !ok {
 			f.metrics(incoming).droppedRemoteIP.Inc(1)
 			return ErrInvalidRemoteIP
 		}
@@ -404,7 +405,8 @@ func (f *Firewall) Drop(packet []byte, fp firewall.Packet, incoming bool, h *Hos
 	}
 
 	// Make sure we are supposed to be handling this local ip address
-	if f.localIps.Contains(fp.LocalIP) == nil {
+	ok, _ := f.localIps.Contains(fp.LocalIP)
+	if !ok {
 		f.metrics(incoming).droppedLocalIP.Inc(1)
 		return ErrInvalidLocalIP
 	}
@@ -657,8 +659,8 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
 		return &FirewallRule{
 			Hosts:     make(map[string]struct{}),
 			Groups:    make([][]string, 0),
-			CIDR:      cidr.NewTree4(),
-			LocalCIDR: cidr.NewTree4(),
+			CIDR:      cidr.NewTree4[struct{}](),
+			LocalCIDR: cidr.NewTree4[struct{}](),
 		}
 	}
 
@@ -726,8 +728,8 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
-		fr.CIDR = cidr.NewTree4()
-		fr.LocalCIDR = cidr.NewTree4()
+		fr.CIDR = cidr.NewTree4[struct{}]()
+		fr.LocalCIDR = cidr.NewTree4[struct{}]()
 	} else {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
@@ -809,12 +811,18 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		}
 	}
 
-	if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
-		return true
+	if fr.CIDR != nil {
+		ok, _ := fr.CIDR.Contains(p.RemoteIP)
+		if ok {
+			return true
+		}
 	}
 
-	if fr.LocalCIDR != nil && fr.LocalCIDR.Contains(p.LocalIP) != nil {
-		return true
+	if fr.LocalCIDR != nil {
+		ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
+		if ok {
+			return true
+		}
 	}
 
 	// No host, group, or cidr matched, bye bye

+ 8 - 4
firewall_test.go

@@ -92,14 +92,16 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
 	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
 	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
@@ -114,8 +116,10 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
 	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
 	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
-	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))
+	ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
+	ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.True(t, ok)
 
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion

+ 2 - 2
hostmap.go

@@ -205,7 +205,7 @@ type HostInfo struct {
 	localIndexId    uint32
 	vpnIp           iputil.VpnIp
 	recvError       atomic.Uint32
-	remoteCidr      *cidr.Tree4
+	remoteCidr      *cidr.Tree4[struct{}]
 	relayState      RelayState
 
 	// HandshakePacket records the packets used to create this hostinfo
@@ -633,7 +633,7 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
 		return
 	}
 
-	remoteCidr := cidr.NewTree4()
+	remoteCidr := cidr.NewTree4[struct{}]()
 	for _, ip := range c.Details.Ips {
 		remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
 	}

+ 4 - 5
lighthouse.go

@@ -74,7 +74,7 @@ type LightHouse struct {
 	// IP's of relays that can be used by peers to access me
 	relaysForMe atomic.Pointer[[]iputil.VpnIp]
 
-	calculatedRemotes atomic.Pointer[cidr.Tree4] // Maps VpnIp to []*calculatedRemote
+	calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
 
 	metrics           *MessageMetrics
 	metricHolepunchTx metrics.Counter
@@ -166,7 +166,7 @@ func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp {
 	return *lh.relaysForMe.Load()
 }
 
-func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4 {
+func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] {
 	return lh.calculatedRemotes.Load()
 }
 
@@ -594,11 +594,10 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool {
 	if tree == nil {
 		return false
 	}
-	value := tree.MostSpecificContains(vpnIp)
-	if value == nil {
+	ok, calculatedRemotes := tree.MostSpecificContains(vpnIp)
+	if !ok {
 		return false
 	}
-	calculatedRemotes := value.([]*calculatedRemote)
 
 	var calculated []*Ip4AndPort
 	for _, cr := range calculatedRemotes {

+ 2 - 2
overlay/route.go

@@ -21,8 +21,8 @@ type Route struct {
 	Install bool
 }
 
-func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4, error) {
-	routeTree := cidr.NewTree4()
+func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) {
+	routeTree := cidr.NewTree4[iputil.VpnIp]()
 	for _, r := range routes {
 		if !allowMTU && r.MTU > 0 {
 			l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)

+ 8 - 10
overlay/route_test.go

@@ -265,18 +265,16 @@ func Test_makeRouteTree(t *testing.T) {
 	assert.NoError(t, err)
 
 	ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2"))
-	r := routeTree.MostSpecificContains(ip)
-	assert.NotNil(t, r)
-	assert.IsType(t, iputil.VpnIp(0), r)
-	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
+	ok, r := routeTree.MostSpecificContains(ip)
+	assert.True(t, ok)
+	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r)
 
 	ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1"))
-	r = routeTree.MostSpecificContains(ip)
-	assert.NotNil(t, r)
-	assert.IsType(t, iputil.VpnIp(0), r)
-	assert.EqualValues(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
+	ok, r = routeTree.MostSpecificContains(ip)
+	assert.True(t, ok)
+	assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r)
 
 	ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1"))
-	r = routeTree.MostSpecificContains(ip)
-	assert.Nil(t, r)
+	ok, r = routeTree.MostSpecificContains(ip)
+	assert.False(t, ok)
 }

+ 4 - 4
overlay/tun_darwin.go

@@ -25,7 +25,7 @@ type tun struct {
 	cidr       *net.IPNet
 	DefaultMTU int
 	Routes     []Route
-	routeTree  *cidr.Tree4
+	routeTree  *cidr.Tree4[iputil.VpnIp]
 	l          *logrus.Logger
 
 	// cache out buffer since we need to prepend 4 bytes for tun metadata
@@ -304,9 +304,9 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
+	ok, r := t.routeTree.MostSpecificContains(ip)
+	if ok {
+		return r
 	}
 
 	return 0

+ 3 - 7
overlay/tun_freebsd.go

@@ -48,7 +48,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -192,12 +192,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 5 - 9
overlay/tun_linux.go

@@ -30,7 +30,7 @@ type tun struct {
 	TXQueueLen int
 
 	Routes          []Route
-	routeTree       atomic.Pointer[cidr.Tree4]
+	routeTree       atomic.Pointer[cidr.Tree4[iputil.VpnIp]]
 	routeChan       chan struct{}
 	useSystemRoutes bool
 
@@ -154,12 +154,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.Load().MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.Load().MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Write(b []byte) (int, error) {
@@ -380,7 +376,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 		return
 	}
 
-	newTree := cidr.NewTree4()
+	newTree := cidr.NewTree4[iputil.VpnIp]()
 	if r.Type == unix.RTM_NEWROUTE {
 		for _, oldR := range t.routeTree.Load().List() {
 			newTree.AddCIDR(oldR.CIDR, oldR.Value)
@@ -392,7 +388,7 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
 	} else {
 		gw := iputil.Ip2VpnIp(r.Gw)
 		for _, oldR := range t.routeTree.Load().List() {
-			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && *oldR.Value != nil && (*oldR.Value).(iputil.VpnIp) == gw {
+			if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw {
 				// This is the record to delete
 				t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
 				continue

+ 3 - 7
overlay/tun_netbsd.go

@@ -29,7 +29,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -134,12 +134,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_openbsd.go

@@ -23,7 +23,7 @@ type tun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	io.ReadWriteCloser
@@ -115,12 +115,8 @@ func (t *tun) Activate() error {
 }
 
 func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *tun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_tester.go

@@ -19,7 +19,7 @@ type TestTun struct {
 	Device    string
 	cidr      *net.IPNet
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 	l         *logrus.Logger
 
 	closed    atomic.Bool
@@ -83,12 +83,8 @@ func (t *TestTun) Get(block bool) []byte {
 //********************************************************************************************************************//
 
 func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *TestTun) Activate() error {

+ 3 - 7
overlay/tun_water_windows.go

@@ -18,7 +18,7 @@ type waterTun struct {
 	cidr      *net.IPNet
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 
 	*water.Interface
 }
@@ -97,12 +97,8 @@ func (t *waterTun) Activate() error {
 }
 
 func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *waterTun) Cidr() *net.IPNet {

+ 3 - 7
overlay/tun_wintun_windows.go

@@ -24,7 +24,7 @@ type winTun struct {
 	prefix    netip.Prefix
 	MTU       int
 	Routes    []Route
-	routeTree *cidr.Tree4
+	routeTree *cidr.Tree4[iputil.VpnIp]
 
 	tun *wintun.NativeTun
 }
@@ -146,12 +146,8 @@ func (t *winTun) Activate() error {
 }
 
 func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp {
-	r := t.routeTree.MostSpecificContains(ip)
-	if r != nil {
-		return r.(iputil.VpnIp)
-	}
-
-	return 0
+	_, r := t.routeTree.MostSpecificContains(ip)
+	return r
 }
 
 func (t *winTun) Cidr() *net.IPNet {