Browse Source

Merge pull request #113 from slackhq/fw-ca

Fixes the issues with caSha and caName
Nathan Brown 5 years ago
parent
commit
e465b13045
3 changed files with 254 additions and 99 deletions
  1. 1 1
      examples/config.yml
  2. 84 41
      firewall.go
  3. 169 57
      firewall_test.go

+ 1 - 1
examples/config.yml

@@ -141,7 +141,7 @@ firewall:
 
   # The firewall is default deny. There is no way to write a deny rule.
   # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
-  # Logical evaluation is roughly: port AND proto AND ca_sha AND ca_name AND (host OR group OR groups OR cidr)
+  # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr)
   # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
   #   code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
   #   proto: `any`, `tcp`, `udp`, or `icmp`

+ 84 - 41
firewall.go

@@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable {
 	}
 }
 
+type FirewallCA struct {
+	Any     *FirewallRule
+	CANames map[string]*FirewallRule
+	CAShas  map[string]*FirewallRule
+}
+
 type FirewallRule struct {
-	// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked
-	Any     bool
-	Hosts   map[string]struct{}
-	Groups  [][]string
-	CIDR    *CIDRTree
-	CANames map[string]struct{}
-	CAShas  map[string]struct{}
+	// Any makes Hosts, Groups, and CIDR irrelevant
+	Any    bool
+	Hosts  map[string]struct{}
+	Groups [][]string
+	CIDR   *CIDRTree
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
 // Plus we can use `-1` for fragment rules
-type firewallPort map[int32]*FirewallRule
+type firewallPort map[int32]*FirewallCA
 
 type FirewallPacket struct {
 	LocalIP    uint32
@@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
 
 func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
 	fw := NewFirewall(
-		c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)),
-		c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)),
-		c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)),
+		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
+		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
+		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
 		nc,
 		//TODO: max_connections
 	)
@@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 
 	for i := startPort; i <= endPort; i++ {
 		if _, ok := fp[i]; !ok {
-			fp[i] = &FirewallRule{
-				Groups:  make([][]string, 0),
-				Hosts:   make(map[string]struct{}),
-				CIDR:    NewCIDRTree(),
-				CANames: make(map[string]struct{}),
-				CAShas:  make(map[string]struct{}),
+			fp[i] = &FirewallCA{
+				CANames: make(map[string]*FirewallRule),
+				CAShas:  make(map[string]*FirewallRule),
 			}
 		}
 
@@ -539,15 +540,70 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 	return fp[fwPortAny].match(p, c, caPool)
 }
 
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
-	if caName != "" {
-		fr.CANames[caName] = struct{}{}
+func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
+	fr := func() *FirewallRule {
+		return &FirewallRule{
+			Hosts:  make(map[string]struct{}),
+			Groups: make([][]string, 0),
+			CIDR:   NewCIDRTree(),
+		}
+	}
+
+	if caSha == "" && caName == "" {
+		if fc.Any == nil {
+			fc.Any = fr()
+		}
+
+		return fc.Any.addRule(groups, host, ip)
 	}
 
 	if caSha != "" {
-		fr.CAShas[caSha] = struct{}{}
+		if _, ok := fc.CAShas[caSha]; !ok {
+			fc.CAShas[caSha] = fr()
+		}
+		err := fc.CAShas[caSha].addRule(groups, host, ip)
+		if err != nil {
+			return err
+		}
+	}
+
+	if caName != "" {
+		if _, ok := fc.CANames[caName]; !ok {
+			fc.CANames[caName] = fr()
+		}
+		err := fc.CANames[caName].addRule(groups, host, ip)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+	if fc == nil {
+		return false
 	}
 
+	if fc.Any.match(p, c) {
+		return true
+	}
+
+	if t, ok := fc.CAShas[c.Details.Issuer]; ok {
+		if t.match(p, c) {
+			return true
+		}
+	}
+
+	s, err := caPool.GetCAForCert(c)
+	if err != nil {
+		return false
+	}
+
+	return fc.CANames[s.Details.Name].match(p, c)
+}
+
+func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
 	if fr.Any {
 		return nil
 	}
@@ -576,6 +632,10 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caN
 }
 
 func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
+	if len(groups) == 0 && host == "" && ip == nil {
+		return true
+	}
+
 	for _, group := range groups {
 		if group == "any" {
 			return true
@@ -593,28 +653,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 	return false
 }
 
-func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
 	if fr == nil {
 		return false
 	}
 
-	// CASha and CAName always need to be checked
-	if len(fr.CAShas) > 0 {
-		if _, ok := fr.CAShas[c.Details.Issuer]; !ok {
-			return false
-		}
-	}
-
-	if len(fr.CANames) > 0 {
-		s, err := caPool.GetCAForCert(c)
-		if err != nil {
-			return false
-		}
-		if _, ok := fr.CANames[s.Details.Name]; !ok {
-			return false
-		}
-	}
-
 	// Shortcut path for if groups, hosts, or cidr contained an `any`
 	if fr.Any {
 		return true
@@ -773,7 +816,7 @@ func setTCPRTTTracking(c *conn, p []byte) {
 	ihl := int(p[0]&0x0f) << 2
 
 	// Don't track FIN packets
-	if uint8(p[ihl+13])&tcpFIN != 0 {
+	if p[ihl+13]&tcpFIN != 0 {
 		return
 	}
 
@@ -787,7 +830,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
 	}
 
 	ihl := int(p[0]&0x0f) << 2
-	if uint8(p[ihl+13])&tcpACK == 0 {
+	if p[ihl+13]&tcpACK == 0 {
 		return false
 	}
 

+ 169 - 57
firewall_test.go

@@ -51,6 +51,11 @@ func TestNewFirewall(t *testing.T) {
 }
 
 func TestFirewall_AddRule(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.NotNil(t, fw.InRules)
@@ -59,39 +64,38 @@ func TestFirewall_AddRule(t *testing.T) {
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 
 	assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
-	// Make sure an empty rule creates structure but doesn't allow anything to flow
-	//TODO: ideally an empty rule would return an error
-	assert.False(t, fw.InRules.TCP[1].Any)
-	assert.Empty(t, fw.InRules.TCP[1].Groups)
-	assert.Empty(t, fw.InRules.TCP[1].Hosts)
-	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
-	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
-	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
+	// An empty rule is any
+	assert.True(t, fw.InRules.TCP[1].Any.Any)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
+	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
+	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
+	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
-	assert.False(t, fw.InRules.UDP[1].Any)
-	assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
-	assert.Empty(t, fw.InRules.UDP[1].Hosts)
-	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
-	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
-	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
+	assert.False(t, fw.InRules.UDP[1].Any.Any)
+	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
+	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
+	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
+	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
+	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
-	assert.False(t, fw.InRules.ICMP[1].Any)
-	assert.Empty(t, fw.InRules.ICMP[1].Groups)
-	assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
-	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
-	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
-	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
+	assert.False(t, fw.InRules.ICMP[1].Any.Any)
+	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
+	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
+	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
+	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
+	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
-	assert.False(t, fw.OutRules.AnyProto[1].Any)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
+	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
@@ -104,28 +108,29 @@ func TestFirewall_AddRule(t *testing.T) {
 	// Set any and clear fields
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
-	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
-	assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
+	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
+	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
+	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
 
 	// run twice just to make sure
+	//TODO: these ANY rules should clear the CA firewall portion
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
-	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
-	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
-	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
@@ -134,6 +139,11 @@ func TestFirewall_AddRule(t *testing.T) {
 }
 
 func TestFirewall_Drop(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
 		ip2int(net.IPv4(1, 2, 3, 4)),
@@ -150,10 +160,11 @@ func TestFirewall_Drop(t *testing.T) {
 
 	c := cert.NebulaCertificate{
 		Details: cert.NebulaCertificateDetails{
-			Name:   "host1",
-			Ips:    []*net.IPNet{&ipNet},
-			Groups: []string{"default-group"},
-			Issuer: "signer-shasum",
+			Name:           "host1",
+			Ips:            []*net.IPNet{&ipNet},
+			Groups:         []string{"default-group"},
+			InvertedGroups: map[string]struct{}{"default-group": {}},
+			Issuer:         "signer-shasum",
 		},
 	}
 	h := HostInfo{
@@ -170,6 +181,7 @@ func TestFirewall_Drop(t *testing.T) {
 	// Drop outbound
 	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
 	// Allow inbound
+	resetConntrack(fw)
 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 	// Allow outbound because conntrack
 	assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
@@ -180,27 +192,31 @@ func TestFirewall_Drop(t *testing.T) {
 	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
 	p.RemoteIP = oldRemote
 
-	// test caSha assertions true
-	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
-	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
-
-	// test caSha assertions false
+	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
 	assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
 
-	// test caName true
-	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
+	// test caSha doesn't drop on match
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 
-	// test caName false
+	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
-	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
 	assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
+
+	// test caName doesn't drop on match
+	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
+	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 }
 
 func BenchmarkFirewallTable_match(b *testing.B) {
@@ -209,11 +225,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	}
 
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -281,7 +297,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		}
 	})
 
-	ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
 
 	b.Run("pass on ip with any port", func(b *testing.B) {
 		ip := ip2int(net.IPv4(172, 1, 1, 1))
@@ -298,6 +314,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 }
 
 func TestFirewall_Drop2(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
 	p := FirewallPacket{
 		ip2int(net.IPv4(1, 2, 3, 4)),
 		ip2int(net.IPv4(1, 2, 3, 4)),
@@ -347,9 +368,94 @@ func TestFirewall_Drop2(t *testing.T) {
 	// h1/c1 lacks the proper groups
 	assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
 	// c has the proper groups
+	resetConntrack(fw)
 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 }
 
+func TestFirewall_Drop3(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
+	p := FirewallPacket{
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		1,
+		1,
+		fwProtoUDP,
+		false,
+	}
+
+	ipNet := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 4),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	c := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name: "host-owner",
+			Ips:  []*net.IPNet{&ipNet},
+		},
+	}
+
+	c1 := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:   "host1",
+			Ips:    []*net.IPNet{&ipNet},
+			Issuer: "signer-sha-bad",
+		},
+	}
+	h1 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c1,
+		},
+	}
+	h1.CreateRemoteCIDR(&c1)
+
+	c2 := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:   "host2",
+			Ips:    []*net.IPNet{&ipNet},
+			Issuer: "signer-sha",
+		},
+	}
+	h2 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c2,
+		},
+	}
+	h2.CreateRemoteCIDR(&c2)
+
+	c3 := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:   "host3",
+			Ips:    []*net.IPNet{&ipNet},
+			Issuer: "signer-sha-bad",
+		},
+	}
+	h3 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c3,
+		},
+	}
+	h3.CreateRemoteCIDR(&c3)
+
+	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
+	cp := cert.NewCAPool()
+
+	// c1 should pass because host match
+	assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
+	// c2 should pass because ca sha match
+	resetConntrack(fw)
+	assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
+	// c3 should fail because no match
+	resetConntrack(fw)
+	assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
+}
+
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {
@@ -748,3 +854,9 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 	mf.nextCallReturn = nil
 	return err
 }
+
+func resetConntrack(fw *Firewall) {
+	fw.connMutex.Lock()
+	fw.Conns = map[FirewallPacket]*conn{}
+	fw.connMutex.Unlock()
+}