Jelajahi Sumber

add acl policy for static nodes on CE

abhishek9686 3 bulan lalu
induk
melakukan
96e1aa1754
4 mengubah file dengan 276 tambahan dan 104 penghapusan
  1. 6 0
      controllers/egress.go
  2. 266 2
      logic/acls.go
  3. 2 2
      pro/initialize.go
  4. 2 100
      pro/logic/acls.go

+ 6 - 0
controllers/egress.go

@@ -92,6 +92,12 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
 				node.IsGw = true
 				node.IsIngressGateway = true
 				node.IsRelay = true
+				if node.Address.IP != nil {
+					node.IngressDNS = node.Address.IP.String()
+				} else {
+					node.IngressDNS = node.Address6.IP.String()
+				}
+
 				logic.UpsertNode(&node)
 			}
 		}

+ 266 - 2
logic/acls.go

@@ -21,8 +21,272 @@ import (
 // TODO: Write Diff Funcs
 
 var IsNodeAllowedToCommunicate = isNodeAllowedToCommunicate
-var GetStaticNodeIps = func(node models.Node) (ips []net.IP) { return }
-var GetFwRulesOnIngressGateway = func(node models.Node) (rules []models.FwRule) { return }
+
+var GetFwRulesForNodeAndPeerOnGw = getFwRulesForNodeAndPeerOnGw
+
+var GetFwRulesForUserNodesOnGw = func(node models.Node, nodes []models.Node) (rules []models.FwRule) { return }
+
+func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
+	// fetch user access to static clients via policies
+	defer func() {
+		sort.Slice(rules, func(i, j int) bool {
+			if !rules[i].SrcIP.IP.Equal(rules[j].SrcIP.IP) {
+				return string(rules[i].SrcIP.IP.To16()) < string(rules[j].SrcIP.IP.To16())
+			}
+			return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
+		})
+	}()
+	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+	nodes, _ := GetNetworkNodes(node.Network)
+	nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
+	rules = GetFwRulesForUserNodesOnGw(node, nodes)
+	if defaultDevicePolicy.Enabled {
+		return
+	}
+	for _, nodeI := range nodes {
+		if !nodeI.IsStatic || nodeI.IsUserNode {
+			continue
+		}
+		// if nodeI.StaticNode.IngressGatewayID != node.ID.String() {
+		// 	continue
+		// }
+		for _, peer := range nodes {
+			if peer.StaticNode.ClientID == nodeI.StaticNode.ClientID || peer.IsUserNode {
+				continue
+			}
+			if nodeI.StaticNode.IngressGatewayID != node.ID.String() &&
+				((!peer.IsStatic && peer.ID.String() != node.ID.String()) ||
+					(peer.IsStatic && peer.StaticNode.IngressGatewayID != node.ID.String())) {
+				continue
+			}
+			if peer.IsStatic {
+				peer = peer.StaticNode.ConvertToStaticNode()
+			}
+			var allowedPolicies1 []models.Acl
+			var ok bool
+			if ok, allowedPolicies1 = IsNodeAllowedToCommunicate(nodeI.StaticNode.ConvertToStaticNode(), peer, true); ok {
+				rules = append(rules, GetFwRulesForNodeAndPeerOnGw(nodeI.StaticNode.ConvertToStaticNode(), peer, allowedPolicies1)...)
+			}
+			if ok, allowedPolicies2 := IsNodeAllowedToCommunicate(peer, nodeI.StaticNode.ConvertToStaticNode(), true); ok {
+				rules = append(rules,
+					GetFwRulesForNodeAndPeerOnGw(peer, nodeI.StaticNode.ConvertToStaticNode(),
+						getUniquePolicies(allowedPolicies1, allowedPolicies2))...)
+			}
+		}
+	}
+	return
+}
+
+func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []models.Acl) (rules []models.FwRule) {
+
+	for _, policy := range allowedPolicies {
+		// if static peer dst rule not for ingress node -> skip
+		if node.Address.IP != nil {
+			rules = append(rules, models.FwRule{
+				SrcIP: net.IPNet{
+					IP:   node.Address.IP,
+					Mask: net.CIDRMask(32, 32),
+				},
+				DstIP: net.IPNet{
+					IP:   peer.Address.IP,
+					Mask: net.CIDRMask(32, 32),
+				},
+				Allow: true,
+			})
+		}
+
+		if node.Address6.IP != nil {
+			rules = append(rules, models.FwRule{
+				SrcIP: net.IPNet{
+					IP:   node.Address6.IP,
+					Mask: net.CIDRMask(128, 128),
+				},
+				DstIP: net.IPNet{
+					IP:   peer.Address6.IP,
+					Mask: net.CIDRMask(128, 128),
+				},
+				Allow: true,
+			})
+		}
+		if policy.AllowedDirection == models.TrafficDirectionBi {
+			if node.Address.IP != nil {
+				rules = append(rules, models.FwRule{
+					SrcIP: net.IPNet{
+						IP:   peer.Address.IP,
+						Mask: net.CIDRMask(32, 32),
+					},
+					DstIP: net.IPNet{
+						IP:   node.Address.IP,
+						Mask: net.CIDRMask(32, 32),
+					},
+					Allow: true,
+				})
+			}
+
+			if node.Address6.IP != nil {
+				rules = append(rules, models.FwRule{
+					SrcIP: net.IPNet{
+						IP:   peer.Address6.IP,
+						Mask: net.CIDRMask(128, 128),
+					},
+					DstIP: net.IPNet{
+						IP:   node.Address6.IP,
+						Mask: net.CIDRMask(128, 128),
+					},
+					Allow: true,
+				})
+			}
+		}
+		if len(node.StaticNode.ExtraAllowedIPs) > 0 {
+			for _, additionalAllowedIPNet := range node.StaticNode.ExtraAllowedIPs {
+				_, ipNet, err := net.ParseCIDR(additionalAllowedIPNet)
+				if err != nil {
+					continue
+				}
+				if ipNet.IP.To4() != nil && peer.Address.IP != nil {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   peer.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				} else if peer.Address6.IP != nil {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   peer.Address6.IP,
+							Mask: net.CIDRMask(128, 128),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				}
+
+			}
+
+		}
+		if len(peer.StaticNode.ExtraAllowedIPs) > 0 {
+			for _, additionalAllowedIPNet := range peer.StaticNode.ExtraAllowedIPs {
+				_, ipNet, err := net.ParseCIDR(additionalAllowedIPNet)
+				if err != nil {
+					continue
+				}
+				if ipNet.IP.To4() != nil && node.Address.IP != nil {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   node.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				} else if node.Address6.IP != nil {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   node.Address6.IP,
+							Mask: net.CIDRMask(128, 128),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				}
+
+			}
+
+		}
+
+		// add egress range rules
+		for _, dstI := range policy.Dst {
+			if dstI.ID == models.EgressID {
+
+				e := schema.Egress{ID: dstI.Value}
+				err := e.Get(db.WithContext(context.TODO()))
+				if err != nil {
+					continue
+				}
+				dstI.Value = e.Range
+
+				ip, cidr, err := net.ParseCIDR(dstI.Value)
+				if err == nil {
+					if ip.To4() != nil {
+						if node.Address.IP != nil {
+							rules = append(rules, models.FwRule{
+								SrcIP: net.IPNet{
+									IP:   node.Address.IP,
+									Mask: net.CIDRMask(32, 32),
+								},
+								DstIP: *cidr,
+								Allow: true,
+							})
+						}
+					} else {
+						if node.Address6.IP != nil {
+							rules = append(rules, models.FwRule{
+								SrcIP: net.IPNet{
+									IP:   node.Address6.IP,
+									Mask: net.CIDRMask(128, 128),
+								},
+								DstIP: *cidr,
+								Allow: true,
+							})
+						}
+					}
+
+				}
+			}
+		}
+	}
+
+	return
+}
+
+func getUniquePolicies(policies1, policies2 []models.Acl) []models.Acl {
+	policies1Map := make(map[string]struct{})
+	for _, policy1I := range policies1 {
+		policies1Map[policy1I.ID] = struct{}{}
+	}
+	for i := len(policies2) - 1; i >= 0; i-- {
+		if _, ok := policies1Map[policies2[i].ID]; ok {
+			policies2 = append(policies2[:i], policies2[i+1:]...)
+		}
+	}
+	return policies2
+}
+
+// Sort a slice of net.IP addresses
+func sortIPs(ips []net.IP) {
+	sort.Slice(ips, func(i, j int) bool {
+		ip1, ip2 := ips[i].To16(), ips[j].To16()
+		return string(ip1) < string(ip2) // Compare as byte slices
+	})
+}
+
+func GetStaticNodeIps(node models.Node) (ips []net.IP) {
+	defer func() {
+		sortIPs(ips)
+	}()
+	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
+	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+
+	extclients := GetStaticNodesByNetwork(models.NetworkID(node.Network), false)
+	for _, extclient := range extclients {
+		if extclient.IsUserNode && defaultUserPolicy.Enabled {
+			continue
+		}
+		if !extclient.IsUserNode && defaultDevicePolicy.Enabled {
+			continue
+		}
+		if extclient.StaticNode.Address != "" {
+			ips = append(ips, extclient.StaticNode.AddressIPNet4().IP)
+		}
+		if extclient.StaticNode.Address6 != "" {
+			ips = append(ips, extclient.StaticNode.AddressIPNet6().IP)
+		}
+	}
+	return
+}
+
 var MigrateToGws = func() {
 
 	nodes, err := GetAllNodes()

+ 2 - 2
pro/initialize.go

@@ -155,8 +155,8 @@ func InitPro() {
 	logic.CheckIfAnyPolicyisUniDirectional = proLogic.CheckIfAnyPolicyisUniDirectional
 	logic.MigrateToGws = proLogic.MigrateToGws
 	logic.IsNodeAllowedToCommunicate = proLogic.IsNodeAllowedToCommunicate
-	logic.GetStaticNodeIps = proLogic.GetStaticNodeIps
-	logic.GetFwRulesOnIngressGateway = proLogic.GetFwRulesOnIngressGateway
+	logic.GetFwRulesForNodeAndPeerOnGw = proLogic.GetFwRulesForNodeAndPeerOnGw
+	logic.GetFwRulesForUserNodesOnGw = proLogic.GetFwRulesForUserNodesOnGw
 
 }
 

+ 2 - 100
pro/logic/acls.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"maps"
 	"net"
-	"sort"
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/db"
@@ -25,7 +24,7 @@ ranges should be replaced by egress identifier
 
 */
 
-func getFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules []models.FwRule) {
+func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules []models.FwRule) {
 	defaultUserPolicy, _ := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 	userNodes := logic.GetStaticUserNodesByNetwork(models.NetworkID(node.Network))
 	for _, userNodeI := range userNodes {
@@ -108,58 +107,7 @@ func getFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules []
 	return
 }
 
-func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
-	// fetch user access to static clients via policies
-	defer func() {
-		sort.Slice(rules, func(i, j int) bool {
-			if !rules[i].SrcIP.IP.Equal(rules[j].SrcIP.IP) {
-				return string(rules[i].SrcIP.IP.To16()) < string(rules[j].SrcIP.IP.To16())
-			}
-			return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
-		})
-	}()
-	defaultDevicePolicy, _ := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
-	nodes, _ := logic.GetNetworkNodes(node.Network)
-	nodes = append(nodes, logic.GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
-	rules = getFwRulesForUserNodesOnGw(node, nodes)
-	if defaultDevicePolicy.Enabled {
-		return
-	}
-	for _, nodeI := range nodes {
-		if !nodeI.IsStatic || nodeI.IsUserNode {
-			continue
-		}
-		// if nodeI.StaticNode.IngressGatewayID != node.ID.String() {
-		// 	continue
-		// }
-		for _, peer := range nodes {
-			if peer.StaticNode.ClientID == nodeI.StaticNode.ClientID || peer.IsUserNode {
-				continue
-			}
-			if nodeI.StaticNode.IngressGatewayID != node.ID.String() &&
-				((!peer.IsStatic && peer.ID.String() != node.ID.String()) ||
-					(peer.IsStatic && peer.StaticNode.IngressGatewayID != node.ID.String())) {
-				continue
-			}
-			if peer.IsStatic {
-				peer = peer.StaticNode.ConvertToStaticNode()
-			}
-			var allowedPolicies1 []models.Acl
-			var ok bool
-			if ok, allowedPolicies1 = IsNodeAllowedToCommunicate(nodeI.StaticNode.ConvertToStaticNode(), peer, true); ok {
-				rules = append(rules, getFwRulesForNodeAndPeerOnGw(nodeI.StaticNode.ConvertToStaticNode(), peer, allowedPolicies1)...)
-			}
-			if ok, allowedPolicies2 := IsNodeAllowedToCommunicate(peer, nodeI.StaticNode.ConvertToStaticNode(), true); ok {
-				rules = append(rules,
-					getFwRulesForNodeAndPeerOnGw(peer, nodeI.StaticNode.ConvertToStaticNode(),
-						getUniquePolicies(allowedPolicies1, allowedPolicies2))...)
-			}
-		}
-	}
-	return
-}
-
-func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []models.Acl) (rules []models.FwRule) {
+func GetFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []models.Acl) (rules []models.FwRule) {
 
 	for _, policy := range allowedPolicies {
 		// if static peer dst rule not for ingress node -> skip
@@ -335,52 +283,6 @@ func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode
 	return
 }
 
-func getUniquePolicies(policies1, policies2 []models.Acl) []models.Acl {
-	policies1Map := make(map[string]struct{})
-	for _, policy1I := range policies1 {
-		policies1Map[policy1I.ID] = struct{}{}
-	}
-	for i := len(policies2) - 1; i >= 0; i-- {
-		if _, ok := policies1Map[policies2[i].ID]; ok {
-			policies2 = append(policies2[:i], policies2[i+1:]...)
-		}
-	}
-	return policies2
-}
-
-func GetStaticNodeIps(node models.Node) (ips []net.IP) {
-	defer func() {
-		sortIPs(ips)
-	}()
-	defaultUserPolicy, _ := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
-	defaultDevicePolicy, _ := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
-
-	extclients := logic.GetStaticNodesByNetwork(models.NetworkID(node.Network), false)
-	for _, extclient := range extclients {
-		if extclient.IsUserNode && defaultUserPolicy.Enabled {
-			continue
-		}
-		if !extclient.IsUserNode && defaultDevicePolicy.Enabled {
-			continue
-		}
-		if extclient.StaticNode.Address != "" {
-			ips = append(ips, extclient.StaticNode.AddressIPNet4().IP)
-		}
-		if extclient.StaticNode.Address6 != "" {
-			ips = append(ips, extclient.StaticNode.AddressIPNet6().IP)
-		}
-	}
-	return
-}
-
-// Sort a slice of net.IP addresses
-func sortIPs(ips []net.IP) {
-	sort.Slice(ips, func(i, j int) bool {
-		ip1, ip2 := ips[i].To16(), ips[j].To16()
-		return string(ip1) < string(ip2) // Compare as byte slices
-	})
-}
-
 func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err error) {
 	switch t.ID {
 	case models.NodeTagID: