Переглянути джерело

merged ce and pro acl rule func

abhishek9686 1 тиждень тому
батько
коміт
818a2e806e
3 змінених файлів з 98 додано та 251 видалено
  1. 96 44
      logic/acls.go
  2. 1 1
      pro/initialize.go
  3. 1 206
      pro/logic/acls.go

+ 96 - 44
logic/acls.go

@@ -26,6 +26,10 @@ var GetEgressUserRulesForNode = func(targetnode *models.Node,
 	rules map[string]models.AclRule) map[string]models.AclRule {
 	return rules
 }
+var GetUserAclRulesForNode = func(targetnode *models.Node,
+	rules map[string]models.AclRule) map[string]models.AclRule {
+	return rules
+}
 
 var GetFwRulesForUserNodesOnGw = func(node models.Node, nodes []models.Node) (rules []models.FwRule) { return }
 
@@ -382,9 +386,13 @@ var CheckIfAnyPolicyisUniDirectional = func(targetNode models.Node, acls []model
 	return false
 }
 
-var GetAclRulesForNode = func(targetnodeI *models.Node) (rules map[string]models.AclRule) {
+func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRule) {
 	targetnode := *targetnodeI
-
+	defer func() {
+		//if !targetnode.IsIngressGateway {
+		rules = GetUserAclRulesForNode(&targetnode, rules)
+		//}
+	}()
 	rules = make(map[string]models.AclRule)
 	if IsNodeAllowedToCommunicateWithAllRsrcs(targetnode) {
 		aclRule := models.AclRule{
@@ -400,20 +408,32 @@ var GetAclRulesForNode = func(targetnodeI *models.Node) (rules map[string]models
 		rules[aclRule.ID] = aclRule
 		return
 	}
+	var taggedNodes map[models.TagID][]models.Node
+	if targetnode.IsIngressGateway {
+		taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false)
+	} else {
+		taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true)
+	}
 	acls := ListDevicePolicies(models.NetworkID(targetnode.Network))
-	targetNodeTags := make(map[models.TagID]struct{})
+	var targetNodeTags = make(map[models.TagID]struct{})
+	if targetnode.Mutex != nil {
+		targetnode.Mutex.Lock()
+		targetNodeTags = maps.Clone(targetnode.Tags)
+		targetnode.Mutex.Unlock()
+	} else {
+		targetNodeTags = maps.Clone(targetnode.Tags)
+	}
+	if targetNodeTags == nil {
+		targetNodeTags = make(map[models.TagID]struct{})
+	}
 	targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{}
 	targetNodeTags["*"] = struct{}{}
-	if targetnodeI.IsGw && !servercfg.IsPro {
-		targetNodeTags[models.TagID(fmt.Sprintf("%s.%s", targetnodeI.Network, models.GwTagName))] = struct{}{}
-	}
 	for _, acl := range acls {
 		if !acl.Enabled {
 			continue
 		}
 		srcTags := ConvAclTagToValueMap(acl.Src)
 		dstTags := ConvAclTagToValueMap(acl.Dst)
-		nodes := []models.Node{}
 		for _, dst := range acl.Dst {
 			if dst.ID == models.EgressID {
 				e := schema.Egress{ID: dst.Value}
@@ -459,35 +479,35 @@ var GetAclRulesForNode = func(targetnodeI *models.Node) (rules map[string]models
 							continue
 						}
 						// Get peers in the tags and add allowed rules
+						nodes := taggedNodes[models.TagID(dst)]
 						if dst != targetnode.ID.String() {
 							node, err := GetNodeByID(dst)
 							if err == nil {
 								nodes = append(nodes, node)
 							}
 						}
-					}
 
-					for _, node := range nodes {
-						if node.ID == targetnode.ID {
-							continue
-						}
-						if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
-							continue
-						}
-						if node.Address.IP != nil {
-							aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-						}
-						if node.Address6.IP != nil {
-							aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-						}
-						if node.IsStatic && node.StaticNode.Address != "" {
-							aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-						}
-						if node.IsStatic && node.StaticNode.Address6 != "" {
-							aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+						for _, node := range nodes {
+							if node.ID == targetnode.ID {
+								continue
+							}
+							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+								continue
+							}
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
+							if node.IsStatic && node.StaticNode.Address != "" {
+								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+							}
+							if node.IsStatic && node.StaticNode.Address6 != "" {
+								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+							}
 						}
 					}
-
 				}
 				if existsInDstTag /*&& !existsInSrcTag*/ {
 					// get all src tags
@@ -496,34 +516,66 @@ var GetAclRulesForNode = func(targetnodeI *models.Node) (rules map[string]models
 							continue
 						}
 						// Get peers in the tags and add allowed rules
+						nodes := taggedNodes[models.TagID(src)]
 						if src != targetnode.ID.String() {
 							node, err := GetNodeByID(src)
 							if err == nil {
 								nodes = append(nodes, node)
 							}
 						}
-					}
-					for _, node := range nodes {
-						if node.ID == targetnode.ID {
-							continue
+						for _, node := range nodes {
+							if node.ID == targetnode.ID {
+								continue
+							}
+							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+								continue
+							}
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
+							if node.IsStatic && node.StaticNode.Address != "" {
+								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+							}
+							if node.IsStatic && node.StaticNode.Address6 != "" {
+								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+							}
 						}
-						if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+					}
+				}
+			} else {
+				_, all := dstTags["*"]
+				if _, ok := dstTags[nodeTag.String()]; ok || all {
+					// get all src tags
+					for src := range srcTags {
+						if src == nodeTag.String() {
 							continue
 						}
-						if node.Address.IP != nil {
-							aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-						}
-						if node.Address6.IP != nil {
-							aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-						}
-						if node.IsStatic && node.StaticNode.Address != "" {
-							aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-						}
-						if node.IsStatic && node.StaticNode.Address6 != "" {
-							aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+						// Get peers in the tags and add allowed rules
+						nodes := taggedNodes[models.TagID(src)]
+						for _, node := range nodes {
+							if node.ID == targetnode.ID {
+								continue
+							}
+							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+								continue
+							}
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
+							if node.IsStatic && node.StaticNode.Address != "" {
+								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+							}
+							if node.IsStatic && node.StaticNode.Address6 != "" {
+								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+							}
 						}
 					}
-
 				}
 			}
 

+ 1 - 1
pro/initialize.go

@@ -149,7 +149,7 @@ func InitPro() {
 	logic.IsAclPolicyValid = proLogic.IsAclPolicyValid
 	logic.GetEgressUserRulesForNode = proLogic.GetEgressUserRulesForNode
 	logic.GetTagMapWithNodesByNetwork = proLogic.GetTagMapWithNodesByNetwork
-	logic.GetAclRulesForNode = proLogic.GetAclRulesForNode
+	logic.GetUserAclRulesForNode = proLogic.GetUserAclRulesForNode
 	logic.CheckIfAnyPolicyisUniDirectional = proLogic.CheckIfAnyPolicyisUniDirectional
 	logic.MigrateToGws = proLogic.MigrateToGws
 	logic.GetFwRulesForNodeAndPeerOnGw = proLogic.GetFwRulesForNodeAndPeerOnGw

+ 1 - 206
pro/logic/acls.go

@@ -3,7 +3,6 @@ package logic
 import (
 	"context"
 	"errors"
-	"fmt"
 	"maps"
 	"net"
 
@@ -926,7 +925,7 @@ func GetEgressUserRulesForNode(targetnode *models.Node,
 	return rules
 }
 
-func getUserAclRulesForNode(targetnode *models.Node,
+func GetUserAclRulesForNode(targetnode *models.Node,
 	rules map[string]models.AclRule) map[string]models.AclRule {
 	userNodes := logic.GetStaticUserNodesByNetwork(models.NetworkID(targetnode.Network))
 	userGrpMap := GetUserGrpMap()
@@ -1096,210 +1095,6 @@ func CheckIfAnyPolicyisUniDirectional(targetNode models.Node, acls []models.Acl)
 	return false
 }
 
-func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRule) {
-	targetnode := *targetnodeI
-	defer func() {
-		//if !targetnode.IsIngressGateway {
-		rules = getUserAclRulesForNode(&targetnode, rules)
-		//}
-	}()
-	rules = make(map[string]models.AclRule)
-	if logic.IsNodeAllowedToCommunicateWithAllRsrcs(targetnode) {
-		aclRule := models.AclRule{
-			ID:              fmt.Sprintf("%s-all-allowed-node-rule", targetnode.ID.String()),
-			AllowedProtocol: models.ALL,
-			Direction:       models.TrafficDirectionBi,
-			Allowed:         true,
-			IPList:          []net.IPNet{targetnode.NetworkRange},
-			IP6List:         []net.IPNet{targetnode.NetworkRange6},
-			Dst:             []net.IPNet{targetnode.Address},
-			Dst6:            []net.IPNet{targetnode.Address6},
-		}
-		rules[aclRule.ID] = aclRule
-		return
-	}
-	var taggedNodes map[models.TagID][]models.Node
-	if targetnode.IsIngressGateway {
-		taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false)
-	} else {
-		taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true)
-	}
-	acls := logic.ListDevicePolicies(models.NetworkID(targetnode.Network))
-	var targetNodeTags = make(map[models.TagID]struct{})
-	if targetnode.Mutex != nil {
-		targetnode.Mutex.Lock()
-		targetNodeTags = maps.Clone(targetnode.Tags)
-		targetnode.Mutex.Unlock()
-	} else {
-		targetNodeTags = maps.Clone(targetnode.Tags)
-	}
-	if targetNodeTags == nil {
-		targetNodeTags = make(map[models.TagID]struct{})
-	}
-	targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{}
-	targetNodeTags["*"] = struct{}{}
-	for _, acl := range acls {
-		if !acl.Enabled {
-			continue
-		}
-		srcTags := logic.ConvAclTagToValueMap(acl.Src)
-		dstTags := logic.ConvAclTagToValueMap(acl.Dst)
-		for _, dst := range acl.Dst {
-			if dst.ID == models.EgressID {
-				e := schema.Egress{ID: dst.Value}
-				err := e.Get(db.WithContext(context.TODO()))
-				if err == nil && e.Status {
-					for nodeID := range e.Nodes {
-						dstTags[nodeID] = struct{}{}
-					}
-				}
-			}
-		}
-		_, srcAll := srcTags["*"]
-		_, dstAll := dstTags["*"]
-		aclRule := models.AclRule{
-			ID:              acl.ID,
-			AllowedProtocol: acl.Proto,
-			AllowedPorts:    acl.Port,
-			Direction:       acl.AllowedDirection,
-			Allowed:         true,
-		}
-		for nodeTag := range targetNodeTags {
-			if acl.AllowedDirection == models.TrafficDirectionBi {
-				var existsInSrcTag bool
-				var existsInDstTag bool
-
-				if _, ok := srcTags[nodeTag.String()]; ok || srcAll {
-					existsInSrcTag = true
-				}
-				if _, ok := srcTags[targetnode.ID.String()]; ok || srcAll {
-					existsInSrcTag = true
-				}
-				if _, ok := dstTags[nodeTag.String()]; ok || dstAll {
-					existsInDstTag = true
-				}
-				if _, ok := dstTags[targetnode.ID.String()]; ok || dstAll {
-					existsInDstTag = true
-				}
-
-				if existsInSrcTag /* && !existsInDstTag*/ {
-					// get all dst tags
-					for dst := range dstTags {
-						if dst == nodeTag.String() {
-							continue
-						}
-						// Get peers in the tags and add allowed rules
-						nodes := taggedNodes[models.TagID(dst)]
-						if dst != targetnode.ID.String() {
-							node, err := logic.GetNodeByID(dst)
-							if err == nil {
-								nodes = append(nodes, node)
-							}
-						}
-
-						for _, node := range nodes {
-							if node.ID == targetnode.ID {
-								continue
-							}
-							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
-								continue
-							}
-							if node.Address.IP != nil {
-								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-							}
-							if node.Address6.IP != nil {
-								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-							}
-							if node.IsStatic && node.StaticNode.Address != "" {
-								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-							}
-							if node.IsStatic && node.StaticNode.Address6 != "" {
-								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
-							}
-						}
-					}
-				}
-				if existsInDstTag /*&& !existsInSrcTag*/ {
-					// get all src tags
-					for src := range srcTags {
-						if src == nodeTag.String() {
-							continue
-						}
-						// Get peers in the tags and add allowed rules
-						nodes := taggedNodes[models.TagID(src)]
-						if src != targetnode.ID.String() {
-							node, err := logic.GetNodeByID(src)
-							if err == nil {
-								nodes = append(nodes, node)
-							}
-						}
-						for _, node := range nodes {
-							if node.ID == targetnode.ID {
-								continue
-							}
-							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
-								continue
-							}
-							if node.Address.IP != nil {
-								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-							}
-							if node.Address6.IP != nil {
-								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-							}
-							if node.IsStatic && node.StaticNode.Address != "" {
-								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-							}
-							if node.IsStatic && node.StaticNode.Address6 != "" {
-								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
-							}
-						}
-					}
-				}
-			} else {
-				_, all := dstTags["*"]
-				if _, ok := dstTags[nodeTag.String()]; ok || all {
-					// get all src tags
-					for src := range srcTags {
-						if src == nodeTag.String() {
-							continue
-						}
-						// Get peers in the tags and add allowed rules
-						nodes := taggedNodes[models.TagID(src)]
-						for _, node := range nodes {
-							if node.ID == targetnode.ID {
-								continue
-							}
-							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
-								continue
-							}
-							if node.Address.IP != nil {
-								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-							}
-							if node.Address6.IP != nil {
-								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-							}
-							if node.IsStatic && node.StaticNode.Address != "" {
-								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-							}
-							if node.IsStatic && node.StaticNode.Address6 != "" {
-								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
-							}
-						}
-					}
-				}
-			}
-
-		}
-
-		if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
-			aclRule.IPList = logic.UniqueIPNetList(aclRule.IPList)
-			aclRule.IP6List = logic.UniqueIPNetList(aclRule.IP6List)
-			rules[acl.ID] = aclRule
-		}
-	}
-	return rules
-}
-
 func GetTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) {
 	tagNodesMap = make(map[models.TagID][]models.Node)
 	nodes, _ := logic.GetNetworkNodes(netID.String())