Browse Source

Add a way to find the most specific network

Nate Brown 5 years ago
parent
commit
c1182869c4
2 changed files with 62 additions and 0 deletions
  1. 23 0
      cidr_radix.go
  2. 39 0
      cidr_radix_test.go

+ 23 - 0
cidr_radix.go

@@ -99,6 +99,29 @@ func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
 	return value
 	return value
 }
 }
 
 
+// Finds the most specific match
+func (tree *CIDRTree) MostSpecificContains(ip uint32) (value interface{}) {
+	bit := startbit
+	node := tree.root
+
+	for node != nil {
+		if node.value != nil {
+			value = node.value
+		}
+
+		if ip&bit != 0 {
+			node = node.right
+		} else {
+			node = node.left
+		}
+
+		bit >>= 1
+
+	}
+
+	return value
+}
+
 // Finds the most specific match
 // Finds the most specific match
 func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
 func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
 	bit := startbit
 	bit := startbit

+ 39 - 0
cidr_radix_test.go

@@ -45,6 +45,45 @@ func TestCIDRTree_Contains(t *testing.T) {
 	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
 	assert.Equal(t, "cool", tree.Contains(ip2int(net.ParseIP("255.255.255.255"))))
 }
 }
 
 
+func TestCIDRTree_MostSpecificContains(t *testing.T) {
+	tree := NewCIDRTree()
+	tree.AddCIDR(getCIDR("1.0.0.0/8"), "1")
+	tree.AddCIDR(getCIDR("2.1.0.0/16"), "2")
+	tree.AddCIDR(getCIDR("3.1.1.0/24"), "3")
+	tree.AddCIDR(getCIDR("4.1.1.0/24"), "4a")
+	tree.AddCIDR(getCIDR("4.1.1.0/30"), "4b")
+	tree.AddCIDR(getCIDR("4.1.1.1/32"), "4c")
+	tree.AddCIDR(getCIDR("254.0.0.0/4"), "5")
+
+	tests := []struct {
+		Result interface{}
+		IP     string
+	}{
+		{"1", "1.0.0.0"},
+		{"1", "1.255.255.255"},
+		{"2", "2.1.0.0"},
+		{"2", "2.1.255.255"},
+		{"3", "3.1.1.0"},
+		{"3", "3.1.1.255"},
+		{"4a", "4.1.1.255"},
+		{"4b", "4.1.1.2"},
+		{"4c", "4.1.1.1"},
+		{"5", "240.0.0.0"},
+		{"5", "255.255.255.255"},
+		{nil, "239.0.0.0"},
+		{nil, "4.1.2.2"},
+	}
+
+	for _, tt := range tests {
+		assert.Equal(t, tt.Result, tree.MostSpecificContains(ip2int(net.ParseIP(tt.IP))))
+	}
+
+	tree = NewCIDRTree()
+	tree.AddCIDR(getCIDR("1.1.1.1/0"), "cool")
+	assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("0.0.0.0"))))
+	assert.Equal(t, "cool", tree.MostSpecificContains(ip2int(net.ParseIP("255.255.255.255"))))
+}
+
 func TestCIDRTree_Match(t *testing.T) {
 func TestCIDRTree_Match(t *testing.T) {
 	tree := NewCIDRTree()
 	tree := NewCIDRTree()
 	tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")
 	tree.AddCIDR(getCIDR("4.1.1.0/32"), "1a")