Browse Source

Actual fix for the real issue with tests

Nate Brown 5 năm trước cách đây
mục cha
commit
2d8a8143de
2 tập tin đã thay đổi với 100 bổ sung18 xóa
  1. 6 15
      firewall.go
  2. 94 3
      firewall_test.go

+ 6 - 15
firewall.go

@@ -541,11 +541,6 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 }
 
 func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
-	// If there is an any rule then there is no need to establish specific ca rules
-	if fc.Any != nil {
-		return fc.Any.addRule(groups, host, ip)
-	}
-
 	fr := func() *FirewallRule {
 		return &FirewallRule{
 			Hosts:  make(map[string]struct{}),
@@ -554,19 +549,11 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caNam
 		}
 	}
 
-	any := false
 	if caSha == "" && caName == "" {
-		any = true
-	}
-
-	if any {
 		if fc.Any == nil {
 			fc.Any = fr()
 		}
 
-		// If it's any we need to wipe out any pre-existing rules to save on memory
-		fc.CAShas = make(map[string]*FirewallRule)
-		fc.CANames = make(map[string]*FirewallRule)
 		return fc.Any.addRule(groups, host, ip)
 	}
 
@@ -598,8 +585,8 @@ func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool
 		return false
 	}
 
-	if fc.Any != nil {
-		return fc.Any.match(p, c)
+	if fc.Any.match(p, c) {
+		return true
 	}
 
 	if t, ok := fc.CAShas[c.Details.Issuer]; ok {
@@ -645,6 +632,10 @@ func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) err
 }
 
 func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
+	if len(groups) == 0 && host == "" && ip == nil {
+		return true
+	}
+
 	for _, group := range groups {
 		if group == "any" {
 			return true

+ 94 - 3
firewall_test.go

@@ -64,9 +64,8 @@ func TestFirewall_AddRule(t *testing.T) {
 	_, ti, _ := net.ParseCIDR("1.2.3.4/32")
 
 	assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
-	// Make sure an empty rule creates structure but doesn't allow anything to flow
-	//TODO: ideally an empty rule would return an error
-	assert.False(t, fw.InRules.TCP[1].Any.Any)
+	// An empty rule is any
+	assert.True(t, fw.InRules.TCP[1].Any.Any)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
 	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
@@ -182,6 +181,7 @@ func TestFirewall_Drop(t *testing.T) {
 	// Drop outbound
 	assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
 	// Allow inbound
+	resetConntrack(fw)
 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 	// Allow outbound because conntrack
 	assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
@@ -368,9 +368,94 @@ func TestFirewall_Drop2(t *testing.T) {
 	// h1/c1 lacks the proper groups
 	assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
 	// c has the proper groups
+	resetConntrack(fw)
 	assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
 }
 
+func TestFirewall_Drop3(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
+	p := FirewallPacket{
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		1,
+		1,
+		fwProtoUDP,
+		false,
+	}
+
+	ipNet := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 4),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	c := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name: "host-owner",
+			Ips:  []*net.IPNet{&ipNet},
+		},
+	}
+
+	c1 := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:   "host1",
+			Ips:    []*net.IPNet{&ipNet},
+			Issuer: "signer-sha-bad",
+		},
+	}
+	h1 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c1,
+		},
+	}
+	h1.CreateRemoteCIDR(&c1)
+
+	c2 := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:   "host2",
+			Ips:    []*net.IPNet{&ipNet},
+			Issuer: "signer-sha",
+		},
+	}
+	h2 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c2,
+		},
+	}
+	h2.CreateRemoteCIDR(&c2)
+
+	c3 := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:   "host3",
+			Ips:    []*net.IPNet{&ipNet},
+			Issuer: "signer-sha-bad",
+		},
+	}
+	h3 := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c3,
+		},
+	}
+	h3.CreateRemoteCIDR(&c3)
+
+	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
+	cp := cert.NewCAPool()
+
+	// c1 should pass because host match
+	assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
+	// c2 should pass because ca sha match
+	resetConntrack(fw)
+	assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
+	// c3 should fail because no match
+	resetConntrack(fw)
+	assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
+}
+
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {
@@ -769,3 +854,9 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 	mf.nextCallReturn = nil
 	return err
 }
+
+func resetConntrack(fw *Firewall) {
+	fw.connMutex.Lock()
+	fw.Conns = map[FirewallPacket]*conn{}
+	fw.connMutex.Unlock()
+}