Nate Brown преди 1 година
родител
ревизия
f346cf4109
променени са 4 файла, в които са добавени 138 реда и са изтрити 126 реда
  1. 1 1
      cidr/tree4.go
  2. 9 1
      examples/config.yml
  3. 113 92
      firewall.go
  4. 15 32
      firewall_test.go

+ 1 - 1
cidr/tree4.go

@@ -144,7 +144,7 @@ func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
 
 type eachFunc[T any] func(T) bool
 
-// EachContains will call a function, passing the value, for each entry until the function returns false or the search is complete
+// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete
 // The final return value will be true if the provided function returned true
 func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool {
 	bit := startbit

+ 9 - 1
examples/config.yml

@@ -316,7 +316,7 @@ firewall:
 
   # The firewall is default deny. There is no way to write a deny rule.
   # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
-  # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr)
+  # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
   # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
   #   code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
   #   proto: `any`, `tcp`, `udp`, or `icmp`
@@ -325,6 +325,7 @@ firewall:
   #   groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
   #   cidr: a remote CIDR, `0.0.0.0/0` is any.
   #   local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes.
+  #      Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate.
   #   ca_name: An issuing CA name
   #   ca_sha: An issuing CA shasum
 
@@ -346,3 +347,10 @@ firewall:
       groups:
         - laptop
         - home
+
+    # Expose a subnet (unsafe route) to hosts with the group remote_client
+    # This example assume you have a subnet of 192.168.100.1/24 or larger encoded in the certificate
+    - port: 8080
+      proto: tcp
+      group: remote_client
+      local_cidr: 192.168.100.1/24

+ 113 - 92
firewall.go

@@ -58,7 +58,9 @@ type Firewall struct {
 	DefaultTimeout time.Duration //linux: 600s
 
 	// Used to ensure we don't emit local packets for ips we don't own
-	localIps *cidr.Tree4[struct{}]
+	localIps     *cidr.Tree4[struct{}]
+	assignedCIDR *net.IPNet
+	hasSubnets   bool
 
 	rules        string
 	rulesVersion uint16
@@ -103,17 +105,22 @@ func newFirewallTable() *FirewallTable {
 }
 
 type FirewallCA struct {
-	Any     *firewallLocalCIDR
-	CANames map[string]*firewallLocalCIDR
-	CAShas  map[string]*firewallLocalCIDR
+	Any     *FirewallRule
+	CANames map[string]*FirewallRule
+	CAShas  map[string]*FirewallRule
 }
 
 type FirewallRule struct {
 	// Any makes Hosts, Groups, and CIDR irrelevant
-	Any    bool
-	Hosts  map[string]struct{}
-	Groups [][]string
-	CIDR   *cidr.Tree4[struct{}]
+	Any    *firewallLocalCIDR
+	Hosts  map[string]*firewallLocalCIDR
+	Groups []*firewallGroups
+	CIDR   *cidr.Tree4[*firewallLocalCIDR]
+}
+
+type firewallGroups struct {
+	Groups    []string
+	LocalCIDR *firewallLocalCIDR
 }
 
 // Even though ports are uint16, int32 maps are faster for lookup
@@ -121,8 +128,8 @@ type FirewallRule struct {
 type firewallPort map[int32]*FirewallCA
 
 type firewallLocalCIDR struct {
-	Any       *FirewallRule
-	LocalCIDR *cidr.Tree4[*FirewallRule]
+	Any       bool
+	LocalCIDR *cidr.Tree4[struct{}]
 }
 
 // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@@ -145,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 	}
 
 	localIps := cidr.NewTree4[struct{}]()
+	var assignedCIDR *net.IPNet
 	for _, ip := range c.Details.Ips {
-		localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
+		ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
+		localIps.AddCIDR(ipNet, struct{}{})
+
+		if assignedCIDR == nil {
+			// Only grabbing the first one in the cert since any more than that currently has undefined behavior
+			assignedCIDR = ipNet
+		}
 	}
 
 	for _, n := range c.Details.Subnets {
@@ -164,6 +178,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 		UDPTimeout:     UDPTimeout,
 		DefaultTimeout: defaultTimeout,
 		localIps:       localIps,
+		assignedCIDR:   assignedCIDR,
+		hasSubnets:     len(c.Details.Subnets) > 0,
 		l:              l,
 
 		metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -276,7 +292,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
 		return fmt.Errorf("unknown protocol %v", proto)
 	}
 
-	return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
+	return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
 }
 
 // GetRuleHash returns a hash representation of all inbound and outbound rules
@@ -630,7 +646,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
 	return false
 }
 
-func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
 	if startPort > endPort {
 		return fmt.Errorf("start port was lower than end port")
 	}
@@ -638,12 +654,12 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
 	for i := startPort; i <= endPort; i++ {
 		if _, ok := fp[i]; !ok {
 			fp[i] = &FirewallCA{
-				CANames: make(map[string]*firewallLocalCIDR),
-				CAShas:  make(map[string]*firewallLocalCIDR),
+				CANames: make(map[string]*FirewallRule),
+				CAShas:  make(map[string]*FirewallRule),
 			}
 		}
 
-		if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
+		if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
 			return err
 		}
 	}
@@ -674,26 +690,28 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
 	return fp[firewall.PortAny].match(p, c, caPool)
 }
 
-func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
-	fl := func() *firewallLocalCIDR {
-		return &firewallLocalCIDR{
-			LocalCIDR: cidr.NewTree4[*FirewallRule](),
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
+	fr := func() *FirewallRule {
+		return &FirewallRule{
+			Hosts:  make(map[string]*firewallLocalCIDR),
+			Groups: make([]*firewallGroups, 0),
+			CIDR:   cidr.NewTree4[*firewallLocalCIDR](),
 		}
 	}
 
 	if caSha == "" && caName == "" {
 		if fc.Any == nil {
-			fc.Any = fl()
+			fc.Any = fr()
 		}
 
-		return fc.Any.addRule(groups, host, ip, localIp)
+		return fc.Any.addRule(f, groups, host, ip, localIp)
 	}
 
 	if caSha != "" {
 		if _, ok := fc.CAShas[caSha]; !ok {
-			fc.CAShas[caSha] = fl()
+			fc.CAShas[caSha] = fr()
 		}
-		err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
+		err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
 		if err != nil {
 			return err
 		}
@@ -701,9 +719,9 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
 
 	if caName != "" {
 		if _, ok := fc.CANames[caName]; !ok {
-			fc.CANames[caName] = fl()
+			fc.CANames[caName] = fr()
 		}
-		err := fc.CANames[caName].addRule(groups, host, ip, localIp)
+		err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
 		if err != nil {
 			return err
 		}
@@ -735,75 +753,56 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
 	return fc.CANames[s.Details.Name].match(p, c)
 }
 
-func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
-	fr := func() *FirewallRule {
-		return &FirewallRule{
-			Hosts:  make(map[string]struct{}),
-			Groups: make([][]string, 0),
-			CIDR:   cidr.NewTree4[struct{}](),
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
+	flc := func() *firewallLocalCIDR {
+		return &firewallLocalCIDR{
+			LocalCIDR: cidr.NewTree4[struct{}](),
 		}
 	}
 
-	if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
-		if fc.Any == nil {
-			fc.Any = fr()
+	if fr.isAny(groups, host, ip) {
+		if fr.Any == nil {
+			fr.Any = flc()
 		}
 
-		return fc.Any.addRule(groups, host, ip)
+		return fr.Any.addRule(f, localCIDR)
 	}
 
-	_, efr := fc.LocalCIDR.GetCIDR(localIp)
-	if efr != nil {
-		return efr.addRule(groups, host, ip)
-	}
-
-	nfr := fr()
-	err := nfr.addRule(groups, host, ip)
-	if err != nil {
-		return err
-	}
-
-	fc.LocalCIDR.AddCIDR(localIp, nfr)
-	return nil
-}
-
-func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
-	if fc == nil {
-		return false
-	}
-
-	if fc.Any.match(p, c) {
-		return true
-	}
-
-	return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
-		return fr.match(p, c)
-	})
-}
+	if len(groups) > 0 {
+		nlc := flc()
+		err := nlc.addRule(f, localCIDR)
+		if err != nil {
+			return err
+		}
 
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
-	if fr.Any {
-		return nil
+		fr.Groups = append(fr.Groups, &firewallGroups{
+			Groups:    groups,
+			LocalCIDR: nlc,
+		})
 	}
 
-	if fr.isAny(groups, host, ip) {
-		fr.Any = true
-		// If it's any we need to wipe out any pre-existing rules to save on memory
-		fr.Groups = make([][]string, 0)
-		fr.Hosts = make(map[string]struct{})
-		fr.CIDR = cidr.NewTree4[struct{}]()
-	} else {
-		if len(groups) > 0 {
-			fr.Groups = append(fr.Groups, groups)
+	if host != "" {
+		nlc := fr.Hosts[host]
+		if nlc == nil {
+			nlc = flc()
 		}
-
-		if host != "" {
-			fr.Hosts[host] = struct{}{}
+		err := nlc.addRule(f, localCIDR)
+		if err != nil {
+			return err
 		}
+		fr.Hosts[host] = nlc
+	}
 
-		if ip != nil {
-			fr.CIDR.AddCIDR(ip, struct{}{})
+	if ip != nil {
+		_, nlc := fr.CIDR.GetCIDR(ip)
+		if nlc == nil {
+			nlc = flc()
 		}
+		err := nlc.addRule(f, localCIDR)
+		if err != nil {
+			return err
+		}
+		fr.CIDR.AddCIDR(ip, nlc)
 	}
 
 	return nil
@@ -837,7 +836,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 	}
 
 	// Shortcut path for if groups, hosts, or cidr contained an `any`
-	if fr.Any {
+	if fr.Any.match(p, c) {
 		return true
 	}
 
@@ -845,7 +844,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 	for _, sg := range fr.Groups {
 		found := false
 
-		for _, g := range sg {
+		for _, g := range sg.Groups {
 			if _, ok := c.Details.InvertedGroups[g]; !ok {
 				found = false
 				break
@@ -854,26 +853,48 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
 			found = true
 		}
 
-		if found {
+		if found && sg.LocalCIDR.match(p, c) {
 			return true
 		}
 	}
 
 	if fr.Hosts != nil {
-		if _, ok := fr.Hosts[c.Details.Name]; ok {
-			return true
+		if flc, ok := fr.Hosts[c.Details.Name]; ok {
+			if flc.match(p, c) {
+				return true
+			}
 		}
 	}
 
-	if fr.CIDR != nil {
-		ok, _ := fr.CIDR.Contains(p.RemoteIP)
-		if ok {
-			return true
+	return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
+		return flc.match(p, c)
+	})
+}
+
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
+	if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
+		if !f.hasSubnets {
+			flc.Any = true
+			return nil
 		}
+		localIp = f.assignedCIDR
 	}
 
-	// No host, group, or cidr matched, bye bye
-	return false
+	flc.LocalCIDR.AddCIDR(localIp, struct{}{})
+	return nil
+}
+
+func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
+	if flc == nil {
+		return false
+	}
+
+	if flc.Any {
+		return true
+	}
+
+	ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
+	return ok
 }
 
 type rule struct {

+ 15 - 32
firewall_test.go

@@ -72,33 +72,32 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
-	assert.Empty(t, fw.InRules.TCP[1].Any.Any.Groups)
-	assert.Empty(t, fw.InRules.TCP[1].Any.Any.Hosts)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
+	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
-	assert.False(t, fw.InRules.UDP[1].Any.Any.Any)
-	assert.Contains(t, fw.InRules.UDP[1].Any.Any.Groups[0], "g1")
-	assert.Empty(t, fw.InRules.UDP[1].Any.Any.Hosts)
+	assert.Nil(t, fw.InRules.UDP[1].Any.Any)
+	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
+	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
-	assert.False(t, fw.InRules.ICMP[1].Any.Any.Any)
-	assert.Empty(t, fw.InRules.ICMP[1].Any.Any.Groups)
-	assert.Contains(t, fw.InRules.ICMP[1].Any.Any.Hosts, "h1")
+	assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
+	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
+	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
-	assert.False(t, fw.OutRules.AnyProto[1].Any.Any.Any)
-	ok, _ := fw.OutRules.AnyProto[1].Any.Any.CIDR.GetCIDR(ti)
+	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
+	ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
 	assert.True(t, ok)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
-	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
-	ok, fr := fw.OutRules.AnyProto[1].Any.LocalCIDR.GetCIDR(ti)
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
+	ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
 	assert.True(t, ok)
-	assert.True(t, fr.Any)
 
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
@@ -108,23 +107,6 @@ func TestFirewall_AddRule(t *testing.T) {
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
 	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
 
-	// Set any and clear fields
-	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
-	ok, fr = fw.OutRules.AnyProto[0].Any.LocalCIDR.GetCIDR(ti)
-	assert.True(t, ok)
-	assert.False(t, fr.Any)
-	assert.Equal(t, []string{"g1", "g2"}, fr.Groups[0])
-	assert.Contains(t, fr.Hosts, "h1")
-
-	// run twice just to make sure
-	//TODO: these ANY rules should clear the CA firewall portion
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
-	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
-	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Groups)
-	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Any.Hosts)
-
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
@@ -222,14 +204,15 @@ func TestFirewall_Drop(t *testing.T) {
 }
 
 func BenchmarkFirewallTable_match(b *testing.B) {
+	f := &Firewall{}
 	ft := FirewallTable{
 		TCP: firewallPort{},
 	}
 
 	_, n, _ := net.ParseCIDR("172.1.1.1/32")
 	goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
-	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
-	_ = ft.TCP.addRule(100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {