Ver código fonte

add validation check for tags

abhishek9686 11 meses atrás
pai
commit
00b082d11c
5 arquivos alterados com 50 adições e 14 exclusões
  1. 31 12
      logic/acls.go
  2. 3 0
      logic/user_mgmt.go
  3. 5 2
      models/acl.go
  4. 1 0
      pro/initialize.go
  5. 10 0
      pro/logic/user_mgmt.go

+ 31 - 12
logic/acls.go

@@ -42,40 +42,59 @@ func IsAclPolicyValid(acl models.Acl) bool {
 			if len(userTagLi) < 2 {
 				break
 			}
-			if userTagLi[0] != models.UserAcl.String() &&
-				userTagLi[0] != models.UserGroupAcl.String() {
+			if userTagLi[0] != models.UserAclID.String() &&
+				userTagLi[0] != models.UserGroupAclID.String() {
 				break
 			}
+			// check if user group is valid
+			if userTagLi[0] == models.UserAclID.String() {
+				_, err := GetUser(userTagLi[1])
+				if err != nil {
+					break
+				}
+			} else if userTagLi[0] == models.UserGroupAclID.String() {
+				err := IsGroupValid(models.UserGroupID(userTagLi[1]))
+				if err != nil {
+					break
+				}
+			}
+
 		}
 		for _, dstI := range acl.Dst {
 			dstILi := strings.Split(dstI, ":")
 			if len(dstILi) < 2 {
 				break
 			}
-			if dstILi[0] == models.UserAcl.String() ||
-				dstILi[0] == models.UserGroupAcl.String() {
+			if dstILi[0] == models.UserAclID.String() ||
+				dstILi[0] == models.UserGroupAclID.String() {
+				break
+			}
+			if dstILi[0] != models.DeviceAclID.String() {
+				break
+			}
+			// check if tag is valid
+			_, err := GetTag(models.TagID(dstILi[1]))
+			if err != nil {
 				break
 			}
 		}
 		isValid = true
 	case models.DevicePolicy:
 		for _, srcI := range acl.Src {
-			userTagLi := strings.Split(srcI, ":")
-			if len(userTagLi) < 2 {
+			deviceTagLi := strings.Split(srcI, ":")
+			if len(deviceTagLi) < 2 {
 				break
 			}
-			if userTagLi[0] == models.UserAcl.String() ||
-				userTagLi[0] == models.UserGroupAcl.String() {
+			if deviceTagLi[0] != models.DeviceAclID.String() {
 				break
 			}
 		}
 		for _, dstI := range acl.Dst {
-			dstILi := strings.Split(dstI, ":")
-			if len(dstILi) < 2 {
+			deviceTagLi := strings.Split(dstI, ":")
+			if len(deviceTagLi) < 2 {
 				break
 			}
-			if dstILi[0] == models.UserAcl.String() ||
-				dstILi[0] == models.UserGroupAcl.String() {
+			if deviceTagLi[0] != models.DeviceAclID.String() {
 				break
 			}
 		}

+ 3 - 0
logic/user_mgmt.go

@@ -39,6 +39,9 @@ var FilterNetworksByRole = func(allnetworks []models.Network, user models.User)
 var IsGroupsValid = func(groups map[models.UserGroupID]struct{}) error {
 	return nil
 }
+var IsGroupValid = func(groupID models.UserGroupID) error {
+	return nil
+}
 var IsNetworkRolesValid = func(networkRoles map[models.NetworkID]map[models.UserRoleID]struct{}) error {
 	return nil
 }

+ 5 - 2
models/acl.go

@@ -26,8 +26,11 @@ const (
 type AclGroupType string
 
 const (
-	UserAcl      AclGroupType = "user"
-	UserGroupAcl AclGroupType = "user-group"
+	UserAclID                AclGroupType = "user"
+	UserGroupAclID           AclGroupType = "user-group"
+	DeviceAclID              AclGroupType = "tag"
+	NetmakerIPAclID          AclGroupType = "ip"
+	NetmakerSubNetRangeAClID AclGroupType = "ipset"
 )
 
 func (g AclGroupType) String() string {

+ 1 - 0
pro/initialize.go

@@ -130,6 +130,7 @@ func InitPro() {
 	logic.CreateDefaultNetworkRolesAndGroups = proLogic.CreateDefaultNetworkRolesAndGroups
 	logic.FilterNetworksByRole = proLogic.FilterNetworksByRole
 	logic.IsGroupsValid = proLogic.IsGroupsValid
+	logic.IsGroupValid = proLogic.IsGroupValid
 	logic.IsNetworkRolesValid = proLogic.IsNetworkRolesValid
 	logic.InitialiseRoles = proLogic.UserRolesInit
 	logic.UpdateUserGwAccess = proLogic.UpdateUserGwAccess

+ 10 - 0
pro/logic/user_mgmt.go

@@ -789,6 +789,16 @@ func IsGroupsValid(groups map[models.UserGroupID]struct{}) error {
 	return nil
 }
 
+func IsGroupValid(groupID models.UserGroupID) error {
+
+	_, err := GetUserGroup(groupID)
+	if err != nil {
+		return fmt.Errorf("user group `%s` not found", groupID)
+	}
+
+	return nil
+}
+
 func IsNetworkRolesValid(networkRoles map[models.NetworkID]map[models.UserRoleID]struct{}) error {
 	for netID, netRoles := range networkRoles {