Browse Source

fix ce acl comms

abhishek9686 3 months ago
parent
commit
6a18b6e2b7
3 changed files with 197 additions and 89 deletions
  1. 180 16
      logic/acls.go
  2. 1 1
      logic/peers.go
  3. 16 72
      pro/logic/acls.go

+ 180 - 16
logic/acls.go

@@ -6,6 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"maps"
+	"net"
 	"sort"
 	"sync"
 	"time"
@@ -107,12 +108,47 @@ var CheckIfAnyActiveEgressPolicy = func(targetNode models.Node) bool {
 	targetNodeTags[models.TagID(targetNode.ID.String())] = struct{}{}
 	targetNodeTags["*"] = struct{}{}
 	acls, _ := ListAclsByNetwork(models.NetworkID(targetNode.Network))
+	for _, acl := range acls {
+		if !acl.Enabled || acl.RuleType != models.DevicePolicy {
+			continue
+		}
+		srcTags := ConvAclTagToValueMap(acl.Src)
+		for _, dst := range acl.Dst {
+			if dst.ID == models.EgressID {
+				e := schema.Egress{ID: dst.Value}
+				err := e.Get(db.WithContext(context.TODO()))
+				if err == nil && e.Status {
+					for nodeTag := range targetNodeTags {
+						if _, ok := srcTags[nodeTag.String()]; ok {
+							return true
+						}
+						if _, ok := srcTags[targetNode.ID.String()]; ok {
+							return true
+						}
+					}
+				}
+			}
+		}
+	}
+	return false
+}
+
+var GetAclRulesForNode = func(targetnodeI *models.Node) (rules map[string]models.AclRule) {
+	targetnode := *targetnodeI
+
+	rules = make(map[string]models.AclRule)
+
+	acls := ListDevicePolicies(models.NetworkID(targetnode.Network))
+	targetNodeTags := make(map[models.TagID]struct{})
+	targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{}
+	targetNodeTags["*"] = struct{}{}
 	for _, acl := range acls {
 		if !acl.Enabled {
 			continue
 		}
 		srcTags := ConvAclTagToValueMap(acl.Src)
 		dstTags := ConvAclTagToValueMap(acl.Dst)
+		nodes := []models.Node{}
 		for _, dst := range acl.Dst {
 			if dst.ID == models.EgressID {
 				e := schema.Egress{ID: dst.Value}
@@ -121,37 +157,165 @@ var CheckIfAnyActiveEgressPolicy = func(targetNode models.Node) bool {
 					for nodeID := range e.Nodes {
 						dstTags[nodeID] = struct{}{}
 					}
-					dstTags[e.Range] = struct{}{}
 				}
 			}
 		}
+		_, srcAll := srcTags["*"]
+		_, dstAll := dstTags["*"]
+		aclRule := models.AclRule{
+			ID:              acl.ID,
+			AllowedProtocol: acl.Proto,
+			AllowedPorts:    acl.Port,
+			Direction:       acl.AllowedDirection,
+			Allowed:         true,
+		}
 		for nodeTag := range targetNodeTags {
-			if acl.RuleType == models.DevicePolicy && acl.AllowedDirection == models.TrafficDirectionBi {
-				if _, ok := srcTags[nodeTag.String()]; ok {
-					return true
+			if acl.AllowedDirection == models.TrafficDirectionBi {
+				var existsInSrcTag bool
+				var existsInDstTag bool
+
+				if _, ok := srcTags[nodeTag.String()]; ok || srcAll {
+					existsInSrcTag = true
 				}
-				if _, ok := srcTags[targetNode.ID.String()]; ok {
-					return true
+				if _, ok := srcTags[targetnode.ID.String()]; ok || srcAll {
+					existsInSrcTag = true
+				}
+				if _, ok := dstTags[nodeTag.String()]; ok || dstAll {
+					existsInDstTag = true
+				}
+				if _, ok := dstTags[targetnode.ID.String()]; ok || dstAll {
+					existsInDstTag = true
 				}
-			}
 
-			if _, ok := dstTags[nodeTag.String()]; ok {
-				return true
-			}
-			if _, ok := dstTags[targetNode.ID.String()]; ok {
-				return true
+				if existsInSrcTag /* && !existsInDstTag*/ {
+					// get all dst tags
+					for dst := range dstTags {
+						if dst == nodeTag.String() {
+							continue
+						}
+						// Get peers in the tags and add allowed rules
+						if dst != targetnode.ID.String() {
+							node, err := GetNodeByID(dst)
+							if err == nil {
+								nodes = append(nodes, node)
+							}
+						}
+					}
+
+					for _, node := range nodes {
+						if node.ID == targetnode.ID {
+							continue
+						}
+						if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+							continue
+						}
+						if node.Address.IP != nil {
+							aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+						}
+						if node.Address6.IP != nil {
+							aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+						}
+						if node.IsStatic && node.StaticNode.Address != "" {
+							aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+						}
+						if node.IsStatic && node.StaticNode.Address6 != "" {
+							aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+						}
+					}
+
+				}
+				if existsInDstTag /*&& !existsInSrcTag*/ {
+					// get all src tags
+					for src := range srcTags {
+						if src == nodeTag.String() {
+							continue
+						}
+						// Get peers in the tags and add allowed rules
+						if src != targetnode.ID.String() {
+							node, err := GetNodeByID(src)
+							if err == nil {
+								nodes = append(nodes, node)
+							}
+						}
+					}
+					for _, node := range nodes {
+						if node.ID == targetnode.ID {
+							continue
+						}
+						if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+							continue
+						}
+						if node.Address.IP != nil {
+							aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+						}
+						if node.Address6.IP != nil {
+							aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+						}
+						if node.IsStatic && node.StaticNode.Address != "" {
+							aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+						}
+						if node.IsStatic && node.StaticNode.Address6 != "" {
+							aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+						}
+					}
+
+				}
 			}
+
+		}
+
+		if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
+			aclRule.IPList = UniqueIPNetList(aclRule.IPList)
+			aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
+			rules[acl.ID] = aclRule
 		}
 	}
-	return false
+	return rules
 }
 
-var GetAclRulesForNode = func(targetnodeI *models.Node) (rules map[string]models.AclRule) {
+var GetEgressRulesForNode = func(targetnode models.Node) (rules map[string]models.AclRule) {
 	return
 }
 
-var GetEgressRulesForNode = func(targetnode models.Node) (rules map[string]models.AclRule) {
-	return
+// Compare two IPs and return true if ip1 < ip2
+func lessIP(ip1, ip2 net.IP) bool {
+	ip1 = ip1.To16() // Ensure IPv4 is converted to IPv6-mapped format
+	ip2 = ip2.To16()
+	return string(ip1) < string(ip2)
+}
+
+// Sort by IP first, then by prefix length
+func sortIPNets(ipNets []net.IPNet) {
+	sort.Slice(ipNets, func(i, j int) bool {
+		ip1, ip2 := ipNets[i].IP, ipNets[j].IP
+		mask1, _ := ipNets[i].Mask.Size()
+		mask2, _ := ipNets[j].Mask.Size()
+
+		// Compare IPs first
+		if ip1.Equal(ip2) {
+			return mask1 < mask2 // If same IP, sort by subnet mask size
+		}
+		return lessIP(ip1, ip2)
+	})
+}
+
+func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
+	uniqueMap := make(map[string]net.IPNet)
+
+	for _, ipnet := range ipnets {
+		key := ipnet.String() // Uses CIDR notation as a unique key
+		if _, exists := uniqueMap[key]; !exists {
+			uniqueMap[key] = ipnet
+		}
+	}
+
+	// Convert map back to slice
+	uniqueList := make([]net.IPNet, 0, len(uniqueMap))
+	for _, ipnet := range uniqueMap {
+		uniqueList = append(uniqueList, ipnet)
+	}
+	sortIPNets(uniqueList)
+	return uniqueList
 }
 
 func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err error) {

+ 1 - 1
logic/peers.go

@@ -203,7 +203,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		}
 		defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 		defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
-
+		fmt.Println("====> Checking for ", host.Name, CheckIfAnyActiveEgressPolicy(node), CheckIfNodeHasAccessToAllResources(&node))
 		if (defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) ||
 			(!CheckIfAnyPolicyisUniDirectional(node) && !CheckIfAnyActiveEgressPolicy(node)) ||
 			CheckIfNodeHasAccessToAllResources(&node) {

+ 16 - 72
pro/logic/acls.go

@@ -5,7 +5,6 @@ import (
 	"errors"
 	"maps"
 	"net"
-	"sort"
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/db"
@@ -888,12 +887,12 @@ func getUserAclRulesForNode(targetnode *models.Node,
 			if aclRule, ok := rules[acl.ID]; ok {
 				aclRule.IPList = append(aclRule.IPList, r.IPList...)
 				aclRule.IP6List = append(aclRule.IP6List, r.IP6List...)
-				aclRule.IPList = UniqueIPNetList(aclRule.IPList)
-				aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
+				aclRule.IPList = logic.UniqueIPNetList(aclRule.IPList)
+				aclRule.IP6List = logic.UniqueIPNetList(aclRule.IP6List)
 				rules[acl.ID] = aclRule
 			} else {
-				r.IPList = UniqueIPNetList(r.IPList)
-				r.IP6List = UniqueIPNetList(r.IP6List)
+				r.IPList = logic.UniqueIPNetList(r.IPList)
+				r.IP6List = logic.UniqueIPNetList(r.IP6List)
 				rules[acl.ID] = r
 			}
 		}
@@ -920,38 +919,24 @@ func CheckIfAnyActiveEgressPolicy(targetNode models.Node) bool {
 	targetNodeTags["*"] = struct{}{}
 	acls, _ := logic.ListAclsByNetwork(models.NetworkID(targetNode.Network))
 	for _, acl := range acls {
-		if !acl.Enabled {
+		if !acl.Enabled || acl.RuleType != models.DevicePolicy {
 			continue
 		}
 		srcTags := logic.ConvAclTagToValueMap(acl.Src)
-		dstTags := logic.ConvAclTagToValueMap(acl.Dst)
 		for _, dst := range acl.Dst {
 			if dst.ID == models.EgressID {
 				e := schema.Egress{ID: dst.Value}
 				err := e.Get(db.WithContext(context.TODO()))
 				if err == nil && e.Status {
-					for nodeID := range e.Nodes {
-						dstTags[nodeID] = struct{}{}
+					for nodeTag := range targetNodeTags {
+						if _, ok := srcTags[nodeTag.String()]; ok {
+							return true
+						}
+						if _, ok := srcTags[targetNode.ID.String()]; ok {
+							return true
+						}
 					}
-					dstTags[e.Range] = struct{}{}
-				}
-			}
-		}
-		for nodeTag := range targetNodeTags {
-			if acl.RuleType == models.DevicePolicy && acl.AllowedDirection == models.TrafficDirectionBi {
-				if _, ok := srcTags[nodeTag.String()]; ok {
-					return true
 				}
-				if _, ok := srcTags[targetNode.ID.String()]; ok {
-					return true
-				}
-			}
-
-			if _, ok := dstTags[nodeTag.String()]; ok {
-				return true
-			}
-			if _, ok := dstTags[targetNode.ID.String()]; ok {
-				return true
 			}
 		}
 	}
@@ -1229,8 +1214,8 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu
 		}
 
 		if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
-			aclRule.IPList = UniqueIPNetList(aclRule.IPList)
-			aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
+			aclRule.IPList = logic.UniqueIPNetList(aclRule.IPList)
+			aclRule.IP6List = logic.UniqueIPNetList(aclRule.IP6List)
 			rules[acl.ID] = aclRule
 		}
 	}
@@ -1462,8 +1447,8 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
 
 		}
 		if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
-			aclRule.IPList = UniqueIPNetList(aclRule.IPList)
-			aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
+			aclRule.IPList = logic.UniqueIPNetList(aclRule.IPList)
+			aclRule.IP6List = logic.UniqueIPNetList(aclRule.IP6List)
 			rules[acl.ID] = aclRule
 		}
 
@@ -1471,47 +1456,6 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
 	return
 }
 
-// Compare two IPs and return true if ip1 < ip2
-func lessIP(ip1, ip2 net.IP) bool {
-	ip1 = ip1.To16() // Ensure IPv4 is converted to IPv6-mapped format
-	ip2 = ip2.To16()
-	return string(ip1) < string(ip2)
-}
-
-// Sort by IP first, then by prefix length
-func sortIPNets(ipNets []net.IPNet) {
-	sort.Slice(ipNets, func(i, j int) bool {
-		ip1, ip2 := ipNets[i].IP, ipNets[j].IP
-		mask1, _ := ipNets[i].Mask.Size()
-		mask2, _ := ipNets[j].Mask.Size()
-
-		// Compare IPs first
-		if ip1.Equal(ip2) {
-			return mask1 < mask2 // If same IP, sort by subnet mask size
-		}
-		return lessIP(ip1, ip2)
-	})
-}
-
-func UniqueIPNetList(ipnets []net.IPNet) []net.IPNet {
-	uniqueMap := make(map[string]net.IPNet)
-
-	for _, ipnet := range ipnets {
-		key := ipnet.String() // Uses CIDR notation as a unique key
-		if _, exists := uniqueMap[key]; !exists {
-			uniqueMap[key] = ipnet
-		}
-	}
-
-	// Convert map back to slice
-	uniqueList := make([]net.IPNet, 0, len(uniqueMap))
-	for _, ipnet := range uniqueMap {
-		uniqueList = append(uniqueList, ipnet)
-	}
-	sortIPNets(uniqueList)
-	return uniqueList
-}
-
 func GetInetClientsFromAclPolicies(eID string) (inetClientIDs []string) {
 	e := schema.Egress{ID: eID}
 	err := e.Get(db.WithContext(context.TODO()))