Nate Brown 1 rok temu
rodzic
commit
8f44f22c37
4 zmienionych plików z 239 dodań i 166 usunięć
  1. 38 8
      cidr/tree4.go
  2. 18 39
      cidr/tree4_test.go
  3. 73 40
      firewall.go
  4. 110 79
      firewall_test.go

+ 38 - 8
cidr/tree4.go

@@ -142,15 +142,22 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
 	return ok, value
 }
 
-// Match finds the most specific match
-// TODO this is exact match
-func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
+type eachFunc[T any] func(T) bool
+
+// EachContains will call a function, passing the value, for each entry until the function returns false or the search is complete
+// The final return value will be true if the provided function returned true
+func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
 	bit := startbit
 	node := tree.root
-	lastNode := node
 
 	for node != nil {
-		lastNode = node
+		if node.hasValue {
+			// If the each func returns true then we can exit the loop
+			if each(node.value) {
+				return true
+			}
+		}
+
 		if ip&bit != 0 {
 			node = node.right
 		} else {
@@ -160,10 +167,33 @@ func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
 		bit >>= 1
 	}
 
-	if bit == 0 && lastNode != nil {
-		value = lastNode.value
-		ok = true
+	return false
+}
+
+// GetCIDR returns the entry added by the most recent matching AddCIDR call
+func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) {
+	bit := startbit
+	node := tree.root
+
+	ip := iputil.Ip2VpnIp(cidr.IP)
+	mask := iputil.Ip2VpnIp(cidr.Mask)
+
+	// Find our last ancestor in the tree
+	for node != nil && bit&mask != 0 {
+		if ip&bit != 0 {
+			node = node.right
+		} else {
+			node = node.left
+		}
+
+		bit = bit >> 1
+	}
+
+	if bit&mask == 0 && node != nil {
+		value = node.value
+		ok = node.hasValue
 	}
+
 	return ok, value
 }
 

+ 18 - 39
cidr/tree4_test.go

@@ -115,35 +115,36 @@ func TestCIDRTree_MostSpecificContains(t *testing.T) {
 	assert.Equal(t, "cool", r)
 }
 
-func TestCIDRTree_Match(t *testing.T) {
+func TestTree4_GetCIDR(t *testing.T) {
 	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("4.1.1.0/32"), "1a")
-	tree.AddCIDR(Parse("4.1.1.1/32"), "1b")
+	tree.AddCIDR(Parse("1.0.0.0/8"), "1")
+	tree.AddCIDR(Parse("2.1.0.0/16"), "2")
+	tree.AddCIDR(Parse("3.1.1.0/24"), "3")
+	tree.AddCIDR(Parse("4.1.1.0/24"), "4a")
+	tree.AddCIDR(Parse("4.1.1.1/32"), "4b")
+	tree.AddCIDR(Parse("4.1.2.1/32"), "4c")
+	tree.AddCIDR(Parse("254.0.0.0/4"), "5")
 
 	tests := []struct {
 		Found  bool
 		Result interface{}
-		IP     string
+		IPNet  *net.IPNet
 	}{
-		{true, "1a", "4.1.1.0"},
-		{true, "1b", "4.1.1.1"},
+		{true, "1", Parse("1.0.0.0/8")},
+		{true, "2", Parse("2.1.0.0/16")},
+		{true, "3", Parse("3.1.1.0/24")},
+		{true, "4a", Parse("4.1.1.0/24")},
+		{true, "4b", Parse("4.1.1.1/32")},
+		{true, "4c", Parse("4.1.2.1/32")},
+		{true, "5", Parse("254.0.0.0/4")},
+		{false, "", Parse("2.0.0.0/8")},
 	}
 
 	for _, tt := range tests {
-		ok, r := tree.Match(iputil.Ip2VpnIp(net.ParseIP(tt.IP)))
+		ok, r := tree.GetCIDR(tt.IPNet)
 		assert.Equal(t, tt.Found, ok)
 		assert.Equal(t, tt.Result, r)
 	}
-
-	tree = NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.1.1/0"), "cool")
-	ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
-
-	ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255")))
-	assert.True(t, ok)
-	assert.Equal(t, "cool", r)
 }
 
 func BenchmarkCIDRTree_Contains(b *testing.B) {
@@ -167,25 +168,3 @@ func BenchmarkCIDRTree_Contains(b *testing.B) {
 		}
 	})
 }
-
-func BenchmarkCIDRTree_Match(b *testing.B) {
-	tree := NewTree4[string]()
-	tree.AddCIDR(Parse("1.1.0.0/16"), "1")
-	tree.AddCIDR(Parse("1.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("192.2.1.1/32"), "1")
-	tree.AddCIDR(Parse("172.2.1.1/32"), "1")
-
-	ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1"))
-	b.Run("found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Match(ip)
-		}
-	})
-
-	ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255"))
-	b.Run("not found", func(b *testing.B) {
-		for i := 0; i < b.N; i++ {
-			tree.Match(ip)
-		}
-	})
-}

+ 73 - 40
firewall.go

@@ -84,6 +84,8 @@ type FirewallConntrack struct {
 	TimerWheel *TimerWheel[firewall.Packet]
 }
 
+// FirewallTable is the entry point for a rule, the evaluation order is:
+// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
 type FirewallTable struct {
 	TCP      firewallPort
 	UDP      firewallPort
@@ -101,24 +103,28 @@ func newFirewallTable() *FirewallTable {
 }
 
 type FirewallCA struct {
-	Any     *FirewallRule
-	CANames map[string]*FirewallRule
-	CAShas  map[string]*FirewallRule
+	Any     *firewallLocalCIDR
+	CANames map[string]*firewallLocalCIDR
+	CAShas  map[string]*firewallLocalCIDR
 }
 
 type FirewallRule struct {
-	// Any makes Hosts, Groups, CIDR and LocalCIDR irrelevant
-	Any       bool
-	Hosts     map[string]struct{}
-	Groups    [][]string
-	CIDR      *cidr.Tree4[struct{}]
-	LocalCIDR *cidr.Tree4[struct{}]
+	// Any makes Hosts, Groups, and CIDR irrelevant
+	Any    bool
+	Hosts  map[string]struct{}
+	Groups [][]string
+	CIDR   *cidr.Tree4[struct{}]
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
 // Plus we can use `-1` for fragment rules
 type firewallPort map[int32]*FirewallCA
 
+type firewallLocalCIDR struct {
+	Any       *FirewallRule
+	LocalCIDR *cidr.Tree4[*FirewallRule]
+}
+
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
 func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
 	//TODO: error on 0 duration
@@ -632,8 +638,8 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 	for i := startPort; i <= endPort; i++ {
 		if _, ok := fp[i]; !ok {
 			fp[i] = &FirewallCA{
-				CANames: make(map[string]*FirewallRule),
-				CAShas:  make(map[string]*FirewallRule),
+				CANames: make(map[string]*firewallLocalCIDR),
+				CAShas:  make(map[string]*firewallLocalCIDR),
 			}
 		}
 
@@ -669,18 +675,15 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 }
 
 func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
-	fr := func() *FirewallRule {
-		return &FirewallRule{
-			Hosts:     make(map[string]struct{}),
-			Groups:    make([][]string, 0),
-			CIDR:      cidr.NewTree4[struct{}](),
-			LocalCIDR: cidr.NewTree4[struct{}](),
+	fl := func() *firewallLocalCIDR {
+		return &firewallLocalCIDR{
+			LocalCIDR: cidr.NewTree4[*FirewallRule](),
 		}
 	}
 
 	if caSha == "" && caName == "" {
 		if fc.Any == nil {
-			fc.Any = fr()
+			fc.Any = fl()
 		}
 
 		return fc.Any.addRule(groups, host, ip, localIp)
@@ -688,7 +691,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
 
 	if caSha != "" {
 		if _, ok := fc.CAShas[caSha]; !ok {
-			fc.CAShas[caSha] = fr()
+			fc.CAShas[caSha] = fl()
 		}
 		err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
 		if err != nil {
@@ -698,7 +701,7 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
 
 	if caName != "" {
 		if _, ok := fc.CANames[caName]; !ok {
-			fc.CANames[caName] = fr()
+			fc.CANames[caName] = fl()
 		}
 		err := fc.CANames[caName].addRule(groups, host, ip, localIp)
 		if err != nil {
@@ -732,18 +735,63 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 	return fc.CANames[s.Details.Name].match(p, c)
 }
 
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, localIp *net.IPNet) error {
+func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
+	fr := func() *FirewallRule {
+		return &FirewallRule{
+			Hosts:  make(map[string]struct{}),
+			Groups: make([][]string, 0),
+			CIDR:   cidr.NewTree4[struct{}](),
+		}
+	}
+
+	if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
+		if fc.Any == nil {
+			fc.Any = fr()
+		}
+
+		return fc.Any.addRule(groups, host, ip)
+	}
+
+	_, efr := fc.LocalCIDR.GetCIDR(localIp)
+	if efr != nil {
+		return efr.addRule(groups, host, ip)
+	}
+
+	nfr := fr()
+	err := nfr.addRule(groups, host, ip)
+	if err != nil {
+		return err
+	}
+
+	fc.LocalCIDR.AddCIDR(localIp, nfr)
+	return nil
+}
+
+func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+	if fc == nil {
+		return false
+	}
+
+	if fc.Any.match(p, c) {
+		return true
+	}
+
+	return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
+		return fr.match(p, c)
+	})
+}
+
+func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
 	if fr.Any {
 		return nil
 	}
 
-	if fr.isAny(groups, host, ip, localIp) {
+	if fr.isAny(groups, host, ip) {
 		fr.Any = true
 		// If it's any we need to wipe out any pre-existing rules to save on memory
 		fr.Groups = make([][]string, 0)
 		fr.Hosts = make(map[string]struct{})
 		fr.CIDR = cidr.NewTree4[struct{}]()
-		fr.LocalCIDR = cidr.NewTree4[struct{}]()
 	} else {
 		if len(groups) > 0 {
 			fr.Groups = append(fr.Groups, groups)
@@ -756,17 +804,13 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, loc
 		if ip != nil {
 			fr.CIDR.AddCIDR(ip, struct{}{})
 		}
-
-		if localIp != nil {
-			fr.LocalCIDR.AddCIDR(localIp, struct{}{})
-		}
 	}
 
 	return nil
 }
 
-func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPNet) bool {
-	if len(groups) == 0 && host == "" && ip == nil && localIp == nil {
+func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
+	if len(groups) == 0 && host == "" && ip == nil {
 		return true
 	}
 
@@ -784,10 +828,6 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip, localIp *net.IPN
 		return true
 	}
 
-	if localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0)) {
-		return true
-	}
-
 	return false
 }
 
@@ -832,13 +872,6 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 		}
 	}
 
-	if fr.LocalCIDR != nil {
-		ok, _ := fr.LocalCIDR.Contains(p.LocalIP)
-		if ok {
-			return true
-		}
-	}
-
 	// No host, group, or cidr matched, bye bye
 	return false
 }

+ 110 - 79
firewall_test.go

@@ -71,37 +71,34 @@ func TestFirewall_AddRule(t *testing.T) {
 
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
 	// An empty rule is any
-	assert.True(t, fw.InRules.TCP[1].Any.Any)
-	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
-	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
+	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Any.Groups)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
-	assert.False(t, fw.InRules.UDP[1].Any.Any)
-	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
-	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
+	assert.False(t, fw.InRules.UDP[1].Any.Any.Any)
+	assert.Contains(t, fw.InRules.UDP[1].Any.Any.Groups[0], "g1")
+	assert.Empty(t, fw.InRules.UDP[1].Any.Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
-	assert.False(t, fw.InRules.ICMP[1].Any.Any)
-	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
-	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
+	assert.False(t, fw.InRules.ICMP[1].Any.Any.Any)
+	assert.Empty(t, fw.InRules.ICMP[1].Any.Any.Groups)
+	assert.Contains(t, fw.InRules.ICMP[1].Any.Any.Hosts, "h1")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
-	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.False(t, fw.OutRules.AnyProto[1].Any.Any.Any)
+	ok, _ := fw.OutRules.AnyProto[1].Any.Any.CIDR.GetCIDR(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
-	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
-	ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
+	ok, fr := fw.OutRules.AnyProto[1].Any.LocalCIDR.GetCIDR(ti)
 	assert.True(t, ok)
+	assert.True(t, fr.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
@@ -114,29 +111,28 @@ func TestFirewall_AddRule(t *testing.T) {
 	// Set any and clear fields
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
-	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
-	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
-	ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
-	assert.True(t, ok)
-	ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
+	ok, fr = fw.OutRules.AnyProto[0].Any.LocalCIDR.GetCIDR(ti)
 	assert.True(t, ok)
+	assert.False(t, fr.Any)
+	assert.Equal(t, []string{"g1", "g2"}, fr.Groups[0])
+	assert.Contains(t, fr.Hosts, "h1")
 
 	// run twice just to make sure
 	//TODO: these ANY rules should clear the CA firewall portion
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
@@ -231,108 +227,89 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	}
 
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
-	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
+	goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
+	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
+	_ = ft.TCP.addRule(100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
+		// This benchmark is showing us the cost of failing to match the protocol
 		c := &cert.NebulaCertificate{}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
 		}
 	})
 
-	b.Run("fail on port", func(b *testing.B) {
+	b.Run("pass proto, fail on port", func(b *testing.B) {
+		// This benchmark is showing us the cost of matching a specific protocol but failing to match the port
 		c := &cert.NebulaCertificate{}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
-		}
-	})
-
-	b.Run("fail all group, name, and cidr", func(b *testing.B) {
-		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "nope",
-				Ips:            []*net.IPNet{ip},
-			},
-		}
-		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
 		}
 	})
 
-	b.Run("pass on group", func(b *testing.B) {
-		c := &cert.NebulaCertificate{
-			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"good-group": {}},
-				Name:           "nope",
-			},
-		}
+	b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
+		c := &cert.NebulaCertificate{}
+		ip, _, _ := net.ParseCIDR("9.254.254.254/32")
+		lip := iputil.Ip2VpnIp(ip)
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
 		}
 	})
 
-	b.Run("pass on name", func(b *testing.B) {
+	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
+		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "good-host",
+				Name:           "nope",
+				Ips:            []*net.IPNet{ip},
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 		}
 	})
 
-	b.Run("pass on ip", func(b *testing.B) {
-		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
+		_, ip, _ := net.ParseCIDR("9.254.254.254/32")
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "good-host",
+				Name:           "nope",
+				Ips:            []*net.IPNet{ip},
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
 		}
 	})
 
-	b.Run("pass on local ip", func(b *testing.B) {
-		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	b.Run("pass on group on any local cidr", func(b *testing.B) {
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "good-host",
+				InvertedGroups: map[string]struct{}{"good-group": {}},
+				Name:           "nope",
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 		}
 	})
 
-	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
-
-	b.Run("pass on ip with any port", func(b *testing.B) {
-		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	b.Run("pass on group on specific local cidr", func(b *testing.B) {
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
-				InvertedGroups: map[string]struct{}{"nope": {}},
-				Name:           "good-host",
+				InvertedGroups: map[string]struct{}{"good-group": {}},
+				Name:           "nope",
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
 		}
 	})
 
-	b.Run("pass on local ip with any port", func(b *testing.B) {
-		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	b.Run("pass on name", func(b *testing.B) {
 		c := &cert.NebulaCertificate{
 			Details: cert.NebulaCertificateDetails{
 				InvertedGroups: map[string]struct{}{"nope": {}},
@@ -340,9 +317,63 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			},
 		}
 		for n := 0; n < b.N; n++ {
-			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
+			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
 		}
 	})
+	//
+	//b.Run("pass on ip", func(b *testing.B) {
+	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	//	c := &cert.NebulaCertificate{
+	//		Details: cert.NebulaCertificateDetails{
+	//			InvertedGroups: map[string]struct{}{"nope": {}},
+	//			Name:           "good-host",
+	//		},
+	//	}
+	//	for n := 0; n < b.N; n++ {
+	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
+	//	}
+	//})
+	//
+	//b.Run("pass on local ip", func(b *testing.B) {
+	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	//	c := &cert.NebulaCertificate{
+	//		Details: cert.NebulaCertificateDetails{
+	//			InvertedGroups: map[string]struct{}{"nope": {}},
+	//			Name:           "good-host",
+	//		},
+	//	}
+	//	for n := 0; n < b.N; n++ {
+	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
+	//	}
+	//})
+	//
+	//_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
+	//
+	//b.Run("pass on ip with any port", func(b *testing.B) {
+	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	//	c := &cert.NebulaCertificate{
+	//		Details: cert.NebulaCertificateDetails{
+	//			InvertedGroups: map[string]struct{}{"nope": {}},
+	//			Name:           "good-host",
+	//		},
+	//	}
+	//	for n := 0; n < b.N; n++ {
+	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
+	//	}
+	//})
+	//
+	//b.Run("pass on local ip with any port", func(b *testing.B) {
+	//	ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
+	//	c := &cert.NebulaCertificate{
+	//		Details: cert.NebulaCertificateDetails{
+	//			InvertedGroups: map[string]struct{}{"nope": {}},
+	//			Name:           "good-host",
+	//		},
+	//	}
+	//	for n := 0; n < b.N; n++ {
+	//		ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
+	//	}
+	//})
 }
 
 func TestFirewall_Drop2(t *testing.T) {