Parcourir la source

add node acl rules to fw update

abhishek9686 il y a 7 mois
Parent
commit
9ba818ffa9
1 fichiers modifiés avec 88 ajouts et 36 suppressions
  1. 88 36
      logic/acls.go

+ 88 - 36
logic/acls.go

@@ -567,6 +567,10 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []mode
 			allowedPolicies = append(allowedPolicies, policy)
 			continue
 		}
+		if _, ok := dstMap[peer.ID.String()]; ok {
+			allowedPolicies = append(allowedPolicies, policy)
+			continue
+		}
 		for tagID := range peer.Tags {
 			if _, ok := dstMap[tagID.String()]; ok {
 				allowedPolicies = append(allowedPolicies, policy)
@@ -623,10 +627,15 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.Node) bool {
 	// check for node ID
 	if _, ok := srcMap[node.ID.String()]; ok {
-		return true
+		if _, ok = dstMap[peer.ID.String()]; ok {
+			return true
+		}
+
 	}
 	if _, ok := dstMap[node.ID.String()]; ok {
-		return true
+		if _, ok = srcMap[peer.ID.String()]; ok {
+			return true
+		}
 	}
 	for tagID := range node.Tags {
 		if _, ok := dstMap[tagID.String()]; ok {
@@ -892,30 +901,34 @@ func getUserAclRulesForNode(targetnode *models.Node,
 	userGrpMap := GetUserGrpMap()
 	allowedUsers := make(map[string][]models.Acl)
 	acls := listUserPolicies(models.NetworkID(targetnode.Network))
-	for nodeTag := range targetnode.Tags {
-		for _, acl := range acls {
-			if !acl.Enabled {
-				continue
+
+	for _, acl := range acls {
+		if !acl.Enabled {
+			continue
+		}
+		dstTags := convAclTagToValueMap(acl.Dst)
+		for nodeTag := range targetnode.Tags {
+			if _, ok := dstTags[nodeTag.String()]; !ok {
+				if _, ok = dstTags[targetnode.ID.String()]; !ok {
+					continue
+				}
 			}
-			dstTags := convAclTagToValueMap(acl.Dst)
-			if _, ok := dstTags[nodeTag.String()]; ok {
-				// get all src tags
-				for _, srcAcl := range acl.Src {
-					if srcAcl.ID == models.UserAclID {
-						allowedUsers[srcAcl.Value] = append(allowedUsers[srcAcl.Value], acl)
-					} else if srcAcl.ID == models.UserGroupAclID {
-						// fetch all users in the group
-						if usersMap, ok := userGrpMap[models.UserGroupID(srcAcl.Value)]; ok {
-							for userName := range usersMap {
-								allowedUsers[userName] = append(allowedUsers[userName], acl)
-							}
+			// get all src tags
+			for _, srcAcl := range acl.Src {
+				if srcAcl.ID == models.UserAclID {
+					allowedUsers[srcAcl.Value] = append(allowedUsers[srcAcl.Value], acl)
+				} else if srcAcl.ID == models.UserGroupAclID {
+					// fetch all users in the group
+					if usersMap, ok := userGrpMap[models.UserGroupID(srcAcl.Value)]; ok {
+						for userName := range usersMap {
+							allowedUsers[userName] = append(allowedUsers[userName], acl)
 						}
 					}
 				}
-
 			}
 		}
 	}
+
 	for _, userNode := range userNodes {
 		if !userNode.StaticNode.Enabled {
 			continue
@@ -973,20 +986,21 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 
 	acls := listDevicePolicies(models.NetworkID(targetnode.Network))
 	targetnode.Tags["*"] = struct{}{}
-	for nodeTag := range targetnode.Tags {
-		for _, acl := range acls {
-			if !acl.Enabled {
-				continue
-			}
-			srcTags := convAclTagToValueMap(acl.Src)
-			dstTags := convAclTagToValueMap(acl.Dst)
-			aclRule := models.AclRule{
-				ID:              acl.ID,
-				AllowedProtocol: acl.Proto,
-				AllowedPorts:    acl.Port,
-				Direction:       acl.AllowedDirection,
-				Allowed:         true,
-			}
+
+	for _, acl := range acls {
+		if !acl.Enabled {
+			continue
+		}
+		srcTags := convAclTagToValueMap(acl.Src)
+		dstTags := convAclTagToValueMap(acl.Dst)
+		aclRule := models.AclRule{
+			ID:              acl.ID,
+			AllowedProtocol: acl.Proto,
+			AllowedPorts:    acl.Port,
+			Direction:       acl.AllowedDirection,
+			Allowed:         true,
+		}
+		for nodeTag := range targetnode.Tags {
 			if acl.AllowedDirection == models.TrafficDirectionBi {
 				var existsInSrcTag bool
 				var existsInDstTag bool
@@ -994,9 +1008,15 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 				if _, ok := srcTags[nodeTag.String()]; ok {
 					existsInSrcTag = true
 				}
+				if _, ok := srcTags[targetnode.ID.String()]; ok {
+					existsInSrcTag = true
+				}
 				if _, ok := dstTags[nodeTag.String()]; ok {
 					existsInDstTag = true
 				}
+				if _, ok := dstTags[targetnode.ID.String()]; ok {
+					existsInDstTag = true
+				}
 
 				if existsInSrcTag && !existsInDstTag {
 					// get all dst tags
@@ -1006,6 +1026,13 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 						}
 						// Get peers in the tags and add allowed rules
 						nodes := taggedNodes[models.TagID(dst)]
+						if dst != targetnode.ID.String() {
+							node, err := GetNodeByID(dst)
+							if err == nil {
+								nodes = append(nodes, node)
+							}
+						}
+
 						for _, node := range nodes {
 							if node.ID == targetnode.ID {
 								continue
@@ -1033,6 +1060,12 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 						}
 						// Get peers in the tags and add allowed rules
 						nodes := taggedNodes[models.TagID(src)]
+						if src != targetnode.ID.String() {
+							node, err := GetNodeByID(src)
+							if err == nil {
+								nodes = append(nodes, node)
+							}
+						}
 						for _, node := range nodes {
 							if node.ID == targetnode.ID {
 								continue
@@ -1054,6 +1087,24 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 				}
 				if existsInDstTag && existsInSrcTag {
 					nodes := taggedNodes[nodeTag]
+					for srcID := range srcTags {
+						if srcID == targetnode.ID.String() {
+							continue
+						}
+						node, err := GetNodeByID(srcID)
+						if err == nil {
+							nodes = append(nodes, node)
+						}
+					}
+					for dstID := range dstTags {
+						if dstID == targetnode.ID.String() {
+							continue
+						}
+						node, err := GetNodeByID(dstID)
+						if err == nil {
+							nodes = append(nodes, node)
+						}
+					}
 					for _, node := range nodes {
 						if node.ID == targetnode.ID {
 							continue
@@ -1102,9 +1153,10 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 					}
 				}
 			}
-			if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
-				rules[acl.ID] = aclRule
-			}
+
+		}
+		if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
+			rules[acl.ID] = aclRule
 		}
 	}
 	return rules