Преглед на файлове

simplify egress acl rules

abhishek9686 преди 1 седмица
родител
ревизия
0cdc238a82
променени са 2 файла, в които са добавени 54 реда и са изтрити 171 реда
  1. 44 171
      logic/acls.go
  2. 10 0
      migrate/migrate.go

+ 44 - 171
logic/acls.go

@@ -558,22 +558,24 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
 	if len(egs) == 0 {
 		return
 	}
+	var egressIDMap = make(map[string]schema.Egress)
 	for _, egI := range egs {
 		if !egI.Status {
 			continue
 		}
 		if _, ok := egI.Nodes[targetnode.ID.String()]; ok {
-			targetNodeTags[models.TagID(egI.Range)] = struct{}{}
-			targetNodeTags[models.TagID(egI.ID)] = struct{}{}
+			egressIDMap[egI.ID] = egI
 		}
 	}
+	if len(egressIDMap) == 0 {
+		return
+	}
 	for _, acl := range acls {
 		if !acl.Enabled {
 			continue
 		}
 		srcTags := ConvAclTagToValueMap(acl.Src)
 		dstTags := ConvAclTagToValueMap(acl.Dst)
-		_, srcAll := srcTags["*"]
 		_, dstAll := dstTags["*"]
 		aclRule := models.AclRule{
 			ID:              acl.ID,
@@ -582,10 +584,9 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
 			Direction:       acl.AllowedDirection,
 			Allowed:         true,
 		}
-		for nodeTag := range targetNodeTags {
-
-			if nodeTag != "*" {
-				ip, cidr, err := net.ParseCIDR(nodeTag.String())
+		for egressID, egI := range egressIDMap {
+			if _, ok := dstTags[egressID]; ok || dstAll {
+				ip, cidr, err := net.ParseCIDR(egI.Range)
 				if err == nil {
 					if ip.To4() != nil {
 						aclRule.Dst = append(aclRule.Dst, *cidr)
@@ -593,134 +594,20 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
 						aclRule.Dst6 = append(aclRule.Dst6, *cidr)
 					}
 				}
-			}
-			if acl.AllowedDirection == models.TrafficDirectionBi {
-				var existsInSrcTag bool
-				var existsInDstTag bool
-				if _, ok := srcTags[nodeTag.String()]; ok || srcAll {
-					existsInSrcTag = true
-				}
-				if _, ok := dstTags[nodeTag.String()]; ok || dstAll {
-					existsInDstTag = true
-				}
-				// if srcAll || dstAll {
-				// 	if targetnode.NetworkRange.IP != nil {
-				// 		aclRule.IPList = append(aclRule.IPList, targetnode.NetworkRange)
-				// 	}
-				// 	if targetnode.NetworkRange6.IP != nil {
-				// 		aclRule.IP6List = append(aclRule.IP6List, targetnode.NetworkRange6)
-				// 	}
-				// 	break
-				// }
-				if existsInSrcTag && !existsInDstTag {
-					// get all dst tags
-					for dst := range dstTags {
-						if dst == nodeTag.String() {
-							continue
-						}
-						// Get peers in the tags and add allowed rules
-						nodes := taggedNodes[models.TagID(dst)]
-						if dst != targetnode.ID.String() {
-							node, err := GetNodeByID(dst)
-							if err == nil {
-								nodes = append(nodes, node)
-							}
-							extclient, err := GetExtClient(dst, targetnode.Network)
-							if err == nil {
-								nodes = append(nodes, extclient.ConvertToStaticNode())
-							}
-						}
-
-						for _, node := range nodes {
-							if node.ID == targetnode.ID {
-								continue
-							}
-							if node.Address.IP != nil {
-								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-							}
-							if node.Address6.IP != nil {
-								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-							}
-							if node.IsStatic && node.StaticNode.Address != "" {
-								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-							}
-							if node.IsStatic && node.StaticNode.Address6 != "" {
-								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
-							}
-						}
+				_, srcAll := srcTags["*"]
+				if srcAll {
+					if targetnode.NetworkRange.IP != nil {
+						aclRule.IPList = append(aclRule.IPList, targetnode.NetworkRange)
 					}
-				}
-				if existsInDstTag && !existsInSrcTag {
-					// get all src tags
-					for src := range srcTags {
-						if src == nodeTag.String() {
-							continue
-						}
-						// Get peers in the tags and add allowed rules
-						nodes := taggedNodes[models.TagID(src)]
-						if src != targetnode.ID.String() {
-							node, err := GetNodeByID(src)
-							if err == nil {
-								nodes = append(nodes, node)
-							} else {
-								extclient, err := GetExtClient(src, targetnode.Network)
-								if err == nil {
-									nodes = append(nodes, extclient.ConvertToStaticNode())
-								}
-							}
-
-						}
-						for _, node := range nodes {
-							if node.ID == targetnode.ID {
-								continue
-							}
-							if node.Address.IP != nil {
-								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-							}
-							if node.Address6.IP != nil {
-								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-							}
-							if node.IsStatic && node.StaticNode.Address != "" {
-								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-							}
-							if node.IsStatic && node.StaticNode.Address6 != "" {
-								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
-							}
-						}
+					if targetnode.NetworkRange6.IP != nil {
+						aclRule.IP6List = append(aclRule.IP6List, targetnode.NetworkRange6)
 					}
+					continue
 				}
-				if existsInDstTag && existsInSrcTag {
-					nodes := taggedNodes[nodeTag]
-					for srcID := range srcTags {
-						if srcID == targetnode.ID.String() {
-							continue
-						}
-						node, err := GetNodeByID(srcID)
-						if err == nil {
-							nodes = append(nodes, node)
-						} else {
-							extclient, err := GetExtClient(srcID, targetnode.Network)
-							if err == nil {
-								nodes = append(nodes, extclient.ConvertToStaticNode())
-							}
-						}
-
-					}
-					for dstID := range dstTags {
-						if dstID == targetnode.ID.String() {
-							continue
-						}
-						node, err := GetNodeByID(dstID)
-						if err == nil {
-							nodes = append(nodes, node)
-						} else {
-							extclient, err := GetExtClient(dstID, targetnode.Network)
-							if err == nil {
-								nodes = append(nodes, extclient.ConvertToStaticNode())
-							}
-						}
-
-					}
+				// get all src tags
+				for src := range srcTags {
+					// Get peers in the tags and add allowed rules
+					nodes := taggedNodes[models.TagID(src)]
 					for _, node := range nodes {
 						if node.ID == targetnode.ID {
 							continue
@@ -739,46 +626,9 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
 						}
 					}
 				}
-			} else {
-				if dstAll {
-					if targetnode.NetworkRange.IP != nil {
-						aclRule.IPList = append(aclRule.IPList, targetnode.NetworkRange)
-					}
-					if targetnode.NetworkRange6.IP != nil {
-						aclRule.IP6List = append(aclRule.IP6List, targetnode.NetworkRange6)
-					}
-					break
-				}
-				if _, ok := dstTags[nodeTag.String()]; ok || dstAll {
-					// get all src tags
-					for src := range srcTags {
-						if src == nodeTag.String() {
-							continue
-						}
-						// Get peers in the tags and add allowed rules
-						nodes := taggedNodes[models.TagID(src)]
-						for _, node := range nodes {
-							if node.ID == targetnode.ID {
-								continue
-							}
-							if node.Address.IP != nil {
-								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
-							}
-							if node.Address6.IP != nil {
-								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
-							}
-							if node.IsStatic && node.StaticNode.Address != "" {
-								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
-							}
-							if node.IsStatic && node.StaticNode.Address6 != "" {
-								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
-							}
-						}
-					}
-				}
 			}
-
 		}
+
 		if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
 			aclRule.IPList = UniqueIPNetList(aclRule.IPList)
 			aclRule.IP6List = UniqueIPNetList(aclRule.IP6List)
@@ -1820,6 +1670,7 @@ func getTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (
 	nodes, _ := GetNetworkNodes(netID.String())
 	netGwTag := models.TagID(fmt.Sprintf("%s.%s", netID.String(), models.GwTagName))
 	for _, nodeI := range nodes {
+		tagNodesMap[models.TagID(nodeI.ID.String())] = append(tagNodesMap[models.TagID(nodeI.ID.String())], nodeI)
 		if nodeI.IsGw {
 			tagNodesMap[netGwTag] = append(tagNodesMap[netGwTag], nodeI)
 		}
@@ -1828,5 +1679,27 @@ func getTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (
 	if !withStaticNodes {
 		return
 	}
-	return
+	return addTagMapWithStaticNodes(netID, tagNodesMap)
+}
+
+func addTagMapWithStaticNodes(netID models.NetworkID,
+	tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node {
+	extclients, err := GetNetworkExtClients(netID.String())
+	if err != nil {
+		return tagNodesMap
+	}
+	for _, extclient := range extclients {
+		if extclient.RemoteAccessClientID != "" {
+			continue
+		}
+		tagNodesMap[models.TagID(extclient.ClientID)] = []models.Node{
+			{
+				IsStatic:   true,
+				StaticNode: extclient,
+			},
+		}
+		tagNodesMap["*"] = append(tagNodesMap["*"], extclient.ConvertToStaticNode())
+
+	}
+	return tagNodesMap
 }

+ 10 - 0
migrate/migrate.go

@@ -671,6 +671,16 @@ func createDefaultTagsAndPolicies() {
 		logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.NetID, "all-remote-access-gws")})
 	}
 	logic.MigrateAclPolicies()
+	if !servercfg.IsPro {
+		nodes, _ := logic.GetAllNodes()
+		for _, node := range nodes {
+			if node.IsGw {
+				node.Tags = make(map[models.TagID]struct{})
+				node.Tags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{}
+				logic.UpsertNode(&node)
+			}
+		}
+	}
 }
 
 func migrateToEgressV1() {