Răsfoiți Sursa

add acl policy checker

abhishek9686 11 luni în urmă
părinte
comite
630928b4f7
3 a modificat fișierele cu 74 adăugiri și 1 ștergeri
  1. 9 1
      controllers/acls.go
  2. 54 0
      logic/acls.go
  3. 11 0
      models/acl.go

+ 9 - 1
controllers/acls.go

@@ -85,7 +85,11 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
 	acl.ID = uuid.New()
 	acl.CreatedBy = user.UserName
 	acl.CreatedAt = time.Now().UTC()
-
+	// validate create acl policy
+	if !logic.IsAclPolicyValid(acl) {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
+		return
+	}
 	err = logic.InsertAcl(acl)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
@@ -116,6 +120,10 @@ func updateAcl(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
+	if !logic.IsAclPolicyValid(updateAcl) {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
+		return
+	}
 	err = logic.UpdateAcl(updateAcl, acl)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))

+ 54 - 0
logic/acls.go

@@ -3,6 +3,7 @@ package logic
 import (
 	"encoding/json"
 	"sort"
+	"strings"
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
@@ -30,6 +31,59 @@ func GetAcl(aID string) (models.Acl, error) {
 	return a, nil
 }
 
+func IsAclPolicyValid(acl models.Acl) bool {
+	//check if src and dst are valid
+	isValid := false
+	switch acl.RuleType {
+	case models.UserPolicy:
+		// src list should only contain users
+		for _, srcI := range acl.Src {
+			userTagLi := strings.Split(srcI, ":")
+			if len(userTagLi) < 2 {
+				break
+			}
+			if userTagLi[0] != models.UserAcl.String() &&
+				userTagLi[0] != models.UserGroupAcl.String() {
+				break
+			}
+		}
+		for _, dstI := range acl.Dst {
+			dstILi := strings.Split(dstI, ":")
+			if len(dstILi) < 2 {
+				break
+			}
+			if dstILi[0] == models.UserAcl.String() ||
+				dstILi[0] == models.UserGroupAcl.String() {
+				break
+			}
+		}
+		isValid = true
+	case models.DevicePolicy:
+		for _, srcI := range acl.Src {
+			userTagLi := strings.Split(srcI, ":")
+			if len(userTagLi) < 2 {
+				break
+			}
+			if userTagLi[0] == models.UserAcl.String() ||
+				userTagLi[0] == models.UserGroupAcl.String() {
+				break
+			}
+		}
+		for _, dstI := range acl.Dst {
+			dstILi := strings.Split(dstI, ":")
+			if len(dstILi) < 2 {
+				break
+			}
+			if dstILi[0] == models.UserAcl.String() ||
+				dstILi[0] == models.UserGroupAcl.String() {
+				break
+			}
+		}
+		isValid = true
+	}
+	return isValid
+}
+
 // UpdateAcl - updates allowed fields on acls and commits to DB
 func UpdateAcl(newAcl, acl models.Acl) error {
 	if newAcl.Name != "" {

+ 11 - 0
models/acl.go

@@ -23,6 +23,17 @@ const (
 	DevicePolicy AclPolicyType = "device-policy"
 )
 
+type AclGroupType string
+
+const (
+	UserAcl      AclGroupType = "user"
+	UserGroupAcl AclGroupType = "user-group"
+)
+
+func (g AclGroupType) String() string {
+	return string(g)
+}
+
 type Acl struct {
 	ID               uuid.UUID               `json:"id"`
 	Name             string                  `json:"name"`