瀏覽代碼

fetch user gw via access policy

abhishek9686 11 月之前
父節點
當前提交
3d327bb89e
共有 5 個文件被更改,包括 86 次插入8 次删除
  1. 51 0
      logic/acls.go
  2. 16 6
      logic/nodes.go
  3. 1 1
      logic/tags.go
  4. 1 1
      pro/controllers/users.go
  5. 17 0
      pro/logic/user_mgmt.go

+ 51 - 0
logic/acls.go

@@ -143,6 +143,57 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
 	return models.Acl{}, errors.New("default rule not found")
 }
 
+func ListUserPolicies(u models.User) []models.Acl {
+	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	if err != nil && !database.IsEmptyRecord(err) {
+		return []models.Acl{}
+	}
+	acls := []models.Acl{}
+	for _, dataI := range data {
+		acl := models.Acl{}
+		err := json.Unmarshal([]byte(dataI), &acl)
+		if err != nil {
+			continue
+		}
+
+		if acl.RuleType == models.UserPolicy {
+			srcMap := convAclTagToValueMap(acl.Src)
+			if _, ok := srcMap[u.UserName]; ok {
+				acls = append(acls, acl)
+			} else {
+				// check for user groups
+				for gID := range u.UserGroups {
+					if _, ok := srcMap[gID.String()]; ok {
+						acls = append(acls, acl)
+						break
+					}
+				}
+			}
+
+		}
+	}
+	return acls
+}
+
+func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl {
+	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	if err != nil && !database.IsEmptyRecord(err) {
+		return []models.Acl{}
+	}
+	acls := []models.Acl{}
+	for _, dataI := range data {
+		acl := models.Acl{}
+		err := json.Unmarshal([]byte(dataI), &acl)
+		if err != nil {
+			continue
+		}
+		if acl.NetworkID == netID && acl.RuleType == models.UserPolicy {
+			acls = append(acls, acl)
+		}
+	}
+	return acls
+}
+
 // listDevicePolicies - lists all device policies in a network
 func listDevicePolicies(netID models.NetworkID) []models.Acl {
 	data, err := database.FetchRecords(database.TAG_TABLE_NAME)

+ 16 - 6
logic/nodes.go

@@ -702,7 +702,21 @@ func GetAllFailOvers() ([]models.Node, error) {
 	return igs, nil
 }
 
-func GetTagMapWithNodes(netID models.NetworkID) (tagNodesMap map[models.TagID][]models.Node) {
+func GetTagMapWithNodes() (tagNodesMap map[models.TagID][]models.Node) {
+	tagNodesMap = make(map[models.TagID][]models.Node)
+	nodes, _ := GetAllNodes()
+	for _, nodeI := range nodes {
+		if nodeI.Tags == nil {
+			continue
+		}
+		for nodeTagID := range nodeI.Tags {
+			tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
+		}
+	}
+	return
+}
+
+func GetTagMapWithNodesByNetwork(netID models.NetworkID) (tagNodesMap map[models.TagID][]models.Node) {
 	tagNodesMap = make(map[models.TagID][]models.Node)
 	nodes, _ := GetNetworkNodes(netID.String())
 	for _, nodeI := range nodes {
@@ -710,11 +724,7 @@ func GetTagMapWithNodes(netID models.NetworkID) (tagNodesMap map[models.TagID][]
 			continue
 		}
 		for nodeTagID := range nodeI.Tags {
-			if _, ok := tagNodesMap[nodeTagID]; ok {
-				tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
-			} else {
-				tagNodesMap[nodeTagID] = []models.Node{nodeI}
-			}
+			tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
 		}
 	}
 	return

+ 1 - 1
logic/tags.go

@@ -70,7 +70,7 @@ func ListTagsWithNodes(netID models.NetworkID) ([]models.TagListResp, error) {
 	if err != nil {
 		return []models.TagListResp{}, err
 	}
-	tagsNodeMap := GetTagMapWithNodes(netID)
+	tagsNodeMap := GetTagMapWithNodesByNetwork(netID)
 	resp := []models.TagListResp{}
 	for _, tagI := range tags {
 		tagRespI := models.TagListResp{

+ 1 - 1
pro/controllers/users.go

@@ -861,7 +861,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	userGwNodes := proLogic.GetUserRAGNodes(*user)
+	userGwNodes := proLogic.GetUserRAGNodesV1(*user)
 	for _, extClient := range allextClients {
 		node, ok := userGwNodes[extClient.IngressGatewayID]
 		if !ok {

+ 17 - 0
pro/logic/user_mgmt.go

@@ -508,6 +508,23 @@ func HasNetworkRsrcScope(permissionTemplate models.UserRolePermissionTemplate, n
 	_, ok = rsrcScope[rsrcID]
 	return ok
 }
+
+func GetUserRAGNodesV1(user models.User) (gws map[string]models.Node) {
+	gws = make(map[string]models.Node)
+
+	tagNodesMap := logic.GetTagMapWithNodes()
+	accessPolices := logic.ListUserPolicies(user)
+	for _, policyI := range accessPolices {
+		for _, dstI := range policyI.Dst {
+			if nodes, ok := tagNodesMap[models.TagID(dstI.Value)]; ok {
+				for _, node := range nodes {
+					gws[node.ID.String()] = node
+				}
+			}
+		}
+	}
+	return
+}
 func GetUserRAGNodes(user models.User) (gws map[string]models.Node) {
 	gws = make(map[string]models.Node)
 	userGwAccessScope := GetUserNetworkRolesWithRemoteVPNAccess(user)