瀏覽代碼

remove uuid on id type

abhishek9686 11 月之前
父節點
當前提交
940ed8b2f0
共有 4 個文件被更改,包括 76 次插入27 次删除
  1. 25 11
      controllers/acls.go
  2. 29 13
      logic/acls.go
  3. 1 0
      migrate/migrate.go
  4. 21 3
      models/acl.go

+ 25 - 11
controllers/acls.go

@@ -7,7 +7,6 @@ import (
 	"net/url"
 	"net/url"
 	"time"
 	"time"
 
 
-	"github.com/google/uuid"
 	"github.com/gorilla/mux"
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
@@ -74,15 +73,14 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
-	// check if acl network exists
-	_, err = logic.GetNetwork(req.NetworkID.String())
+	err = logic.ValidateCreateAclReq(req)
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to get network details for "+req.NetworkID.String()), "badrequest"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	// check if acl exists
 	// check if acl exists
 	acl := req
 	acl := req
-	acl.ID = uuid.New()
+	acl.GetID(req.NetworkID, req.Name)
 	acl.CreatedBy = user.UserName
 	acl.CreatedBy = user.UserName
 	acl.CreatedAt = time.Now().UTC()
 	acl.CreatedAt = time.Now().UTC()
 	acl.Default = false
 	acl.Default = false
@@ -107,7 +105,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
 // @Success     200 {array} models.SuccessResponse
 // @Success     200 {array} models.SuccessResponse
 // @Failure     500 {object} models.ErrorResponse
 // @Failure     500 {object} models.ErrorResponse
 func updateAcl(w http.ResponseWriter, r *http.Request) {
 func updateAcl(w http.ResponseWriter, r *http.Request) {
-	var updateAcl models.Acl
+	var updateAcl models.UpdateAclRequest
 	err := json.NewDecoder(r.Body).Decode(&updateAcl)
 	err := json.NewDecoder(r.Body).Decode(&updateAcl)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, "error decoding request body: ",
 		logger.Log(0, "error decoding request body: ",
@@ -116,21 +114,37 @@ func updateAcl(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
-	acl, err := logic.GetAcl(updateAcl.ID.String())
+	acl, err := logic.GetAcl(updateAcl.Acl.ID)
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
-	if !logic.IsAclPolicyValid(updateAcl) {
+	if !logic.IsAclPolicyValid(updateAcl.Acl) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy"), "badrequest"))
 		return
 		return
 	}
 	}
-	err = logic.UpdateAcl(updateAcl, acl)
+	if updateAcl.Acl.NetworkID != acl.NetworkID {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid policy, network id mismatch"), "badrequest"))
+		return
+	}
+	if updateAcl.NewName != "" {
+		//check if policy exists with same name
+		id := models.FormatAclID(updateAcl.Acl.NetworkID, updateAcl.NewName)
+		_, err := logic.GetAcl(id)
+		if err != nil {
+			logic.ReturnErrorResponse(w, r,
+				logic.FormatError(errors.New("policy already exists with name "+updateAcl.NewName), "badrequest"))
+			return
+		}
+		updateAcl.Acl.ID = id
+		updateAcl.Acl.Name = updateAcl.NewName
+	}
+	err = logic.UpdateAcl(updateAcl.Acl, acl)
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
-	logic.ReturnSuccessResponse(w, r, "updated acl "+updateAcl.Name)
+	logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name)
 }
 }
 
 
 // @Summary     Delete Acl
 // @Summary     Delete Acl
@@ -145,7 +159,7 @@ func deleteAcl(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("acl id is required"), "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("acl id is required"), "badrequest"))
 		return
 		return
 	}
 	}
-	acl, err := logic.GetAcl(aclID)
+	acl, err := logic.GetAcl(models.AclID(aclID))
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return

+ 29 - 13
logic/acls.go

@@ -3,10 +3,10 @@ package logic
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
+	"fmt"
 	"sort"
 	"sort"
 	"time"
 	"time"
 
 
-	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 )
 )
@@ -14,9 +14,9 @@ import (
 // CreateDefaultAclNetworkPolicies - create default acl network policies
 // CreateDefaultAclNetworkPolicies - create default acl network policies
 func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
 func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
 	defaultDeviceAcl := models.Acl{
 	defaultDeviceAcl := models.Acl{
-		ID:        uuid.New(),
-		Default:   true,
+		ID:        models.AclID(fmt.Sprintf("%s.%s", netID, "all-nodes")),
 		Name:      "all-nodes",
 		Name:      "all-nodes",
+		Default:   true,
 		NetworkID: netID,
 		NetworkID: netID,
 		RuleType:  models.DevicePolicy,
 		RuleType:  models.DevicePolicy,
 		Src: []models.AclPolicyTag{
 		Src: []models.AclPolicyTag{
@@ -36,7 +36,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
 	}
 	}
 	InsertAcl(defaultDeviceAcl)
 	InsertAcl(defaultDeviceAcl)
 	defaultUserAcl := models.Acl{
 	defaultUserAcl := models.Acl{
-		ID:        uuid.New(),
+		ID:        models.AclID(fmt.Sprintf("%s.%s", netID, "all-users")),
 		Default:   true,
 		Default:   true,
 		Name:      "all-users",
 		Name:      "all-users",
 		NetworkID: netID,
 		NetworkID: netID,
@@ -73,6 +73,19 @@ func DeleteDefaultNetworkPolicies(netId models.NetworkID) {
 	}
 	}
 }
 }
 
 
+// ValidateCreateAclReq - validates create req for acl
+func ValidateCreateAclReq(req models.Acl) error {
+	// check if acl network exists
+	_, err := GetNetwork(req.NetworkID.String())
+	if err != nil {
+		return errors.New("failed to get network details for " + req.NetworkID.String())
+	}
+	if req.Name == "" {
+		return errors.New("name is required")
+	}
+	return nil
+}
+
 // InsertAcl - creates acl policy
 // InsertAcl - creates acl policy
 func InsertAcl(a models.Acl) error {
 func InsertAcl(a models.Acl) error {
 	d, err := json.Marshal(a)
 	d, err := json.Marshal(a)
@@ -83,9 +96,9 @@ func InsertAcl(a models.Acl) error {
 }
 }
 
 
 // GetAcl - gets acl info by id
 // GetAcl - gets acl info by id
-func GetAcl(aID string) (models.Acl, error) {
+func GetAcl(aID models.AclID) (models.Acl, error) {
 	a := models.Acl{}
 	a := models.Acl{}
-	d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID)
+	d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID.String())
 	if err != nil {
 	if err != nil {
 		return a, err
 		return a, err
 	}
 	}
@@ -180,13 +193,16 @@ func IsAclPolicyValid(acl models.Acl) bool {
 
 
 // UpdateAcl - updates allowed fields on acls and commits to DB
 // UpdateAcl - updates allowed fields on acls and commits to DB
 func UpdateAcl(newAcl, acl models.Acl) error {
 func UpdateAcl(newAcl, acl models.Acl) error {
-	if newAcl.Name != "" {
-		acl.Name = newAcl.Name
-	}
+
+	acl.Name = newAcl.Name
 	acl.Src = newAcl.Src
 	acl.Src = newAcl.Src
 	acl.Dst = newAcl.Dst
 	acl.Dst = newAcl.Dst
 	acl.AllowedDirection = newAcl.AllowedDirection
 	acl.AllowedDirection = newAcl.AllowedDirection
 	acl.Enabled = newAcl.Enabled
 	acl.Enabled = newAcl.Enabled
+	if acl.ID != newAcl.ID {
+		database.DeleteRecord(acl.ID.String(), database.ACLS_TABLE_NAME)
+		acl.ID = newAcl.ID
+	}
 	d, err := json.Marshal(acl)
 	d, err := json.Marshal(acl)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -212,7 +228,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
 
 
 // ListUserPolicies - lists all acl policies enforced on an user
 // ListUserPolicies - lists all acl policies enforced on an user
 func ListUserPolicies(u models.User) []models.Acl {
 func ListUserPolicies(u models.User) []models.Acl {
-	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 	if err != nil && !database.IsEmptyRecord(err) {
 		return []models.Acl{}
 		return []models.Acl{}
 	}
 	}
@@ -245,7 +261,7 @@ func ListUserPolicies(u models.User) []models.Acl {
 
 
 // ListUserPoliciesByNetwork - lists all acl user policies in a network
 // ListUserPoliciesByNetwork - lists all acl user policies in a network
 func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl {
 func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl {
-	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 	if err != nil && !database.IsEmptyRecord(err) {
 		return []models.Acl{}
 		return []models.Acl{}
 	}
 	}
@@ -265,7 +281,7 @@ func ListUserPoliciesByNetwork(netID models.NetworkID) []models.Acl {
 
 
 // listDevicePolicies - lists all device policies in a network
 // listDevicePolicies - lists all device policies in a network
 func listDevicePolicies(netID models.NetworkID) []models.Acl {
 func listDevicePolicies(netID models.NetworkID) []models.Acl {
-	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 	if err != nil && !database.IsEmptyRecord(err) {
 		return []models.Acl{}
 		return []models.Acl{}
 	}
 	}
@@ -285,7 +301,7 @@ func listDevicePolicies(netID models.NetworkID) []models.Acl {
 
 
 // ListAcls - lists all acl policies
 // ListAcls - lists all acl policies
 func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
 func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
-	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 	if err != nil && !database.IsEmptyRecord(err) {
 		return []models.Acl{}, err
 		return []models.Acl{}, err
 	}
 	}

+ 1 - 0
migrate/migrate.go

@@ -320,6 +320,7 @@ func syncUsers() {
 		if err == nil {
 		if err == nil {
 			for _, netI := range networks {
 			for _, netI := range networks {
 				logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID))
 				logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID))
+				logic.CreateDefaultAclNetworkPolicies(models.NetworkID(netI.NetID))
 				networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID)
 				networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID)
 				for _, networkNodeI := range networkNodes {
 				for _, networkNodeI := range networkNodes {
 					if networkNodeI.IsIngressGateway {
 					if networkNodeI.IsIngressGateway {

+ 21 - 3
models/acl.go

@@ -1,11 +1,24 @@
 package models
 package models
 
 
 import (
 import (
+	"fmt"
 	"time"
 	"time"
-
-	"github.com/google/uuid"
 )
 )
 
 
+type AclID string
+
+func (aID AclID) String() string {
+	return string(aID)
+}
+
+func (a *Acl) GetID(netID NetworkID, name string) {
+	a.ID = AclID(fmt.Sprintf("%s.%s", netID.String(), name))
+}
+
+func FormatAclID(netID NetworkID, name string) AclID {
+	return AclID(fmt.Sprintf("%s.%s", netID.String(), name))
+}
+
 // AllowedTrafficDirection - allowed direction of traffic
 // AllowedTrafficDirection - allowed direction of traffic
 type AllowedTrafficDirection int
 type AllowedTrafficDirection int
 
 
@@ -42,8 +55,13 @@ func (g AclGroupType) String() string {
 	return string(g)
 	return string(g)
 }
 }
 
 
+type UpdateAclRequest struct {
+	Acl     Acl
+	NewName string `json:"new_name"`
+}
+
 type Acl struct {
 type Acl struct {
-	ID               uuid.UUID               `json:"id"`
+	ID               AclID                   `json:"id"`
 	Default          bool                    `json:"default"`
 	Default          bool                    `json:"default"`
 	Name             string                  `json:"name"`
 	Name             string                  `json:"name"`
 	NetworkID        NetworkID               `json:"network_id"`
 	NetworkID        NetworkID               `json:"network_id"`