소스 검색

add error messaging for acl policy check

abhishek9686 4 달 전
부모
커밋
d1fc2eebb2
2개의 변경된 파일31개의 추가작업 그리고 31개의 파일을 삭제
  1. 2 2
      controllers/acls.go
  2. 29 29
      logic/acls.go

+ 2 - 2
controllers/acls.go

@@ -253,7 +253,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
 		acl.Proto = models.ALL
 	}
 	// validate create acl policy
-	if !logic.IsAclPolicyValid(acl) {
+	if err := logic.IsAclPolicyValid(acl); err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
 		return
 	}
@@ -292,7 +292,7 @@ func updateAcl(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	if !logic.IsAclPolicyValid(updateAcl.Acl) {
+	if err := logic.IsAclPolicyValid(updateAcl.Acl); err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
 		return
 	}

+ 29 - 29
logic/acls.go

@@ -271,26 +271,26 @@ func GetEgressRanges(netID models.NetworkID) (map[string][]string, map[string]st
 	return nodeEgressMap, resultMap, nil
 }
 
-func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) bool {
+func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err error) {
 	switch t.ID {
 	case models.NodeTagID:
 		if a.RuleType == models.UserPolicy && isSrc {
-			return false
+			return errors.New("user policy source mismatch")
 		}
 		// check if tag is valid
 		_, err := GetTag(models.TagID(t.Value))
 		if err != nil {
-			return false
+			return errors.New("invalid tag " + t.Value)
 		}
 	case models.NodeID:
 		if a.RuleType == models.UserPolicy && isSrc {
-			return false
+			return errors.New("user policy source mismatch")
 		}
 		_, nodeErr := GetNodeByID(t.Value)
 		if nodeErr != nil {
 			_, staticNodeErr := GetExtClient(t.Value, a.NetworkID.String())
 			if staticNodeErr != nil {
-				return false
+				return errors.New("invalid node " + t.Value)
 			}
 		}
 	case models.EgressID, models.EgressRange:
@@ -299,7 +299,7 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) bool
 		}
 		err := e.Get(db.WithContext(context.TODO()))
 		if err != nil {
-			return false
+			return errors.New("invalid egress")
 		}
 		if e.IsInetGw {
 			req := models.InetNodeReq{}
@@ -317,10 +317,10 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) bool
 				for k := range e.Nodes {
 					inetNode, err := GetNodeByID(k)
 					if err != nil {
-						return false
+						return errors.New("invalid node " + t.Value)
 					}
-					if ValidateInetGwReq(inetNode, req, false) != nil {
-						return false
+					if err = ValidateInetGwReq(inetNode, req, false); err != nil {
+						return err
 					}
 				}
 
@@ -330,43 +330,43 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) bool
 
 	case models.UserAclID:
 		if a.RuleType == models.DevicePolicy {
-			return false
+			return errors.New("device policy source mismatch")
 		}
 		if !isSrc {
-			return false
+			return errors.New("user cannot be added to destination")
 		}
 		_, err := GetUser(t.Value)
 		if err != nil {
-			return false
+			return errors.New("invalid user " + t.Value)
 		}
 	case models.UserGroupAclID:
 		if a.RuleType == models.DevicePolicy {
-			return false
+			return errors.New("device policy source mismatch")
 		}
 		if !isSrc {
-			return false
+			return errors.New("user cannot be added to destination")
 		}
 		err := IsGroupValid(models.UserGroupID(t.Value))
 		if err != nil {
-			return false
+			return errors.New("invalid user group " + t.Value)
 		}
 		// check if group belongs to this network
 		netGrps := GetUserGroupsInNetwork(a.NetworkID)
 		if _, ok := netGrps[models.UserGroupID(t.Value)]; !ok {
-			return false
+			return errors.New("invalid user group " + t.Value)
 		}
 	default:
-		return false
+		return errors.New("invalid policy")
 	}
-	return true
+	return nil
 }
 
 // IsAclPolicyValid - validates if acl policy is valid
-func IsAclPolicyValid(acl models.Acl) bool {
+func IsAclPolicyValid(acl models.Acl) (err error) {
 	//check if src and dst are valid
 	if acl.AllowedDirection != models.TrafficDirectionBi &&
 		acl.AllowedDirection != models.TrafficDirectionUni {
-		return false
+		return errors.New("invalid traffic direction")
 	}
 	switch acl.RuleType {
 	case models.UserPolicy:
@@ -377,8 +377,8 @@ func IsAclPolicyValid(acl models.Acl) bool {
 				continue
 			}
 			// check if user group is valid
-			if !checkIfAclTagisValid(acl, srcI, true) {
-				return false
+			if err = checkIfAclTagisValid(acl, srcI, true); err != nil {
+				return
 			}
 		}
 		for _, dstI := range acl.Dst {
@@ -388,8 +388,8 @@ func IsAclPolicyValid(acl models.Acl) bool {
 			}
 
 			// check if user group is valid
-			if !checkIfAclTagisValid(acl, dstI, false) {
-				return false
+			if err = checkIfAclTagisValid(acl, dstI, false); err != nil {
+				return
 			}
 		}
 	case models.DevicePolicy:
@@ -398,8 +398,8 @@ func IsAclPolicyValid(acl models.Acl) bool {
 				continue
 			}
 			// check if user group is valid
-			if !checkIfAclTagisValid(acl, srcI, true) {
-				return false
+			if err = checkIfAclTagisValid(acl, srcI, true); err != nil {
+				return err
 			}
 		}
 		for _, dstI := range acl.Dst {
@@ -408,12 +408,12 @@ func IsAclPolicyValid(acl models.Acl) bool {
 				continue
 			}
 			// check if user group is valid
-			if !checkIfAclTagisValid(acl, dstI, false) {
-				return false
+			if err = checkIfAclTagisValid(acl, dstI, false); err != nil {
+				return
 			}
 		}
 	}
-	return true
+	return nil
 }
 
 func UniqueAclPolicyTags(tags []models.AclPolicyTag) []models.AclPolicyTag {