Browse Source

Fix ca* checks

Nate Brown 5 years ago
parent
commit
56657065e0
2 changed files with 135 additions and 80 deletions
  1. 93 41
      firewall.go
  2. 42 39
      firewall_test.go

+ 93 - 41
firewall.go

@@ -83,19 +83,23 @@ func newFirewallTable() *FirewallTable {
 	}
 }
 
+type FirewallCA struct {
+	Any     *FirewallRule
+	CANames map[string]*FirewallRule
+	CAShas  map[string]*FirewallRule
+}
+
 type FirewallRule struct {
-	// Any makes Hosts, Groups, and CIDR irrelevant. CAName and CASha still need to be checked
-	Any     bool
-	Hosts   map[string]struct{}
-	Groups  [][]string
-	CIDR    *CIDRTree
-	CANames map[string]struct{}
-	CAShas  map[string]struct{}
+	// Any makes Hosts, Groups, and CIDR irrelevant
+	Any    bool
+	Hosts  map[string]struct{}
+	Groups [][]string
+	CIDR   *CIDRTree
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
 // Plus we can use `-1` for fragment rules
-type firewallPort map[int32]*FirewallRule
+type firewallPort map[int32]*FirewallCA
 
 type FirewallPacket struct {
 	LocalIP    uint32
@@ -182,9 +186,9 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
 
 func NewFirewallFromConfig(nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
 	fw := NewFirewall(
-		c.GetDuration("firewall.conntrack.tcp_timeout", time.Duration(time.Minute*12)),
-		c.GetDuration("firewall.conntrack.udp_timeout", time.Duration(time.Minute*3)),
-		c.GetDuration("firewall.conntrack.default_timeout", time.Duration(time.Minute*10)),
+		c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
+		c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
+		c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
 		nc,
 		//TODO: max_connections
 	)
@@ -499,12 +503,9 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 
 	for i := startPort; i <= endPort; i++ {
 		if _, ok := fp[i]; !ok {
-			fp[i] = &FirewallRule{
-				Groups:  make([][]string, 0),
-				Hosts:   make(map[string]struct{}),
-				CIDR:    NewCIDRTree(),
-				CANames: make(map[string]struct{}),
-				CAShas:  make(map[string]struct{}),
+			fp[i] = &FirewallCA{
+				CANames: make(map[string]*FirewallRule),
+				CAShas:  make(map[string]*FirewallRule),
 			}
 		}
 
@@ -539,15 +540,83 @@ func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCert
 	return fp[fwPortAny].match(p, c, caPool)
 }
 
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
-	if caName != "" {
-		fr.CANames[caName] = struct{}{}
+func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
+	// If there is an any rule then there is no need to establish specific ca rules
+	if fc.Any != nil {
+		return fc.Any.addRule(groups, host, ip)
+	}
+
+	fr := func() *FirewallRule {
+		return &FirewallRule{
+			Hosts:  make(map[string]struct{}),
+			Groups: make([][]string, 0),
+			CIDR:   NewCIDRTree(),
+		}
+	}
+
+	any := false
+	if caSha == "" && caName == "" {
+		any = true
+	}
+
+	if any {
+		if fc.Any == nil {
+			fc.Any = fr()
+		}
+
+		// If it's any we need to wipe out any pre-existing rules to save on memory
+		fc.CAShas = make(map[string]*FirewallRule)
+		fc.CANames = make(map[string]*FirewallRule)
+		return fc.Any.addRule(groups, host, ip)
 	}
 
 	if caSha != "" {
-		fr.CAShas[caSha] = struct{}{}
+		if _, ok := fc.CAShas[caSha]; !ok {
+			fc.CAShas[caSha] = fr()
+		}
+		err := fc.CAShas[caSha].addRule(groups, host, ip)
+		if err != nil {
+			return err
+		}
+	}
+
+	if caName != "" {
+		if _, ok := fc.CANames[caName]; !ok {
+			fc.CANames[caName] = fr()
+		}
+		err := fc.CANames[caName].addRule(groups, host, ip)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+	if fc == nil {
+		return false
+	}
+
+	if fc.Any != nil {
+		return fc.Any.match(p, c)
 	}
 
+	if t, ok := fc.CAShas[c.Details.Issuer]; ok {
+		if t.match(p, c) {
+			return true
+		}
+	}
+
+	s, err := caPool.GetCAForCert(c)
+	if err != nil {
+		return false
+	}
+
+	return fc.CANames[s.Details.Name].match(p, c)
+}
+
+func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
 	if fr.Any {
 		return nil
 	}
@@ -593,28 +662,11 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
 	return false
 }
 
-func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
+func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
 	if fr == nil {
 		return false
 	}
 
-	// CASha and CAName always need to be checked
-	if len(fr.CAShas) > 0 {
-		if _, ok := fr.CAShas[c.Details.Issuer]; !ok {
-			return false
-		}
-	}
-
-	if len(fr.CANames) > 0 {
-		s, err := caPool.GetCAForCert(c)
-		if err != nil {
-			return false
-		}
-		if _, ok := fr.CANames[s.Details.Name]; !ok {
-			return false
-		}
-	}
-
 	// Shortcut path for if groups, hosts, or cidr contained an `any`
 	if fr.Any {
 		return true
@@ -773,7 +825,7 @@ func setTCPRTTTracking(c *conn, p []byte) {
 	ihl := int(p[0]&0x0f) << 2
 
 	// Don't track FIN packets
-	if uint8(p[ihl+13])&tcpFIN != 0 {
+	if p[ihl+13]&tcpFIN != 0 {
 		return
 	}
 
@@ -787,7 +839,7 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
 	}
 
 	ihl := int(p[0]&0x0f) << 2
-	if uint8(p[ihl+13])&tcpACK == 0 {
+	if p[ihl+13]&tcpACK == 0 {
 		return false
 	}
 

+ 42 - 39
firewall_test.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"math"
 	"net"
 	"testing"
@@ -61,37 +62,37 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
 	// Make sure an empty rule creates structure but doesn't allow anything to flow
 	//TODO: ideally an empty rule would return an error
-	assert.False(t, fw.InRules.TCP[1].Any)
-	assert.Empty(t, fw.InRules.TCP[1].Groups)
-	assert.Empty(t, fw.InRules.TCP[1].Hosts)
-	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
-	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
-	assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
+	assert.False(t, fw.InRules.TCP[1].Any.Any)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
+	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
+	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
+	assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
-	assert.False(t, fw.InRules.UDP[1].Any)
-	assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
-	assert.Empty(t, fw.InRules.UDP[1].Hosts)
-	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
-	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
-	assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
+	assert.False(t, fw.InRules.UDP[1].Any.Any)
+	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
+	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
+	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
+	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
+	assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
-	assert.False(t, fw.InRules.ICMP[1].Any)
-	assert.Empty(t, fw.InRules.ICMP[1].Groups)
-	assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
-	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
-	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
-	assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
+	assert.False(t, fw.InRules.ICMP[1].Any.Any)
+	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
+	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
+	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
+	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
+	assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
-	assert.False(t, fw.OutRules.AnyProto[1].Any)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
-	assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
+	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
@@ -104,28 +105,30 @@ func TestFirewall_AddRule(t *testing.T) {
 	// Set any and clear fields
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
-	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
-	assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
-	assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
+	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
+	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
+	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
 
 	// run twice just to make sure
+	//TODO: these ANY rules should clear the CA firewall portion
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
-	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
-	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
-	assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
+	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
+	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
+	assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
+	fmt.Printf("%+v\n", fw.OutRules.AnyProto[0])
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
 	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
 	assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any)
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
 
 	// Test error conditions
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
@@ -209,11 +212,11 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	}
 
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
-	ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
-	ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -281,7 +284,7 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 		}
 	})
 
-	ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
+	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
 
 	b.Run("pass on ip with any port", func(b *testing.B) {
 		ip := ip2int(net.IPv4(172, 1, 1, 1))