浏览代码

Merge pull request #3184 from gravitl/NET-1767-acls

NET-1767: cache acls v1
Abhishek K 10 月之前
父节点
当前提交
178157cf1c
共有 1 个文件被更改,包括 105 次插入52 次删除
  1. 105 52
      logic/acls.go

+ 105 - 52
logic/acls.go

@@ -5,10 +5,17 @@ import (
 	"errors"
 	"fmt"
 	"sort"
+	"sync"
 	"time"
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
+)
+
+var (
+	aclCacheMutex = &sync.RWMutex{}
+	aclCacheMap   = make(map[string]models.Acl)
 )
 
 // CreateDefaultAclNetworkPolicies - create default acl network policies
@@ -120,18 +127,57 @@ func ValidateCreateAclReq(req models.Acl) error {
 	return nil
 }
 
+func listAclFromCache() (acls []models.Acl) {
+	aclCacheMutex.RLock()
+	defer aclCacheMutex.RUnlock()
+	for _, acl := range aclCacheMap {
+		acls = append(acls, acl)
+	}
+	return
+}
+
+func storeAclInCache(a models.Acl) {
+	aclCacheMutex.Lock()
+	defer aclCacheMutex.Unlock()
+	aclCacheMap[a.ID] = a
+}
+
+func removeAclFromCache(a models.Acl) {
+	aclCacheMutex.Lock()
+	defer aclCacheMutex.Unlock()
+	delete(aclCacheMap, a.ID)
+}
+
+func getAclFromCache(aID string) (a models.Acl, ok bool) {
+	aclCacheMutex.RLock()
+	defer aclCacheMutex.RUnlock()
+	a, ok = aclCacheMap[aID]
+	return
+}
+
 // InsertAcl - creates acl policy
 func InsertAcl(a models.Acl) error {
 	d, err := json.Marshal(a)
 	if err != nil {
 		return err
 	}
-	return database.Insert(a.ID, string(d), database.ACLS_TABLE_NAME)
+	err = database.Insert(a.ID, string(d), database.ACLS_TABLE_NAME)
+	if err == nil && servercfg.CacheEnabled() {
+		storeAclInCache(a)
+	}
+	return err
 }
 
 // GetAcl - gets acl info by id
 func GetAcl(aID string) (models.Acl, error) {
 	a := models.Acl{}
+	if servercfg.CacheEnabled() {
+		var ok bool
+		a, ok = getAclFromCache(aID)
+		if ok {
+			return a, nil
+		}
+	}
 	d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID)
 	if err != nil {
 		return a, err
@@ -140,6 +186,9 @@ func GetAcl(aID string) (models.Acl, error) {
 	if err != nil {
 		return a, err
 	}
+	if servercfg.CacheEnabled() {
+		storeAclInCache(a)
+	}
 	return a, nil
 }
 
@@ -254,7 +303,11 @@ func UpdateAcl(newAcl, acl models.Acl) error {
 	if err != nil {
 		return err
 	}
-	return database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME)
+	err = database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME)
+	if err == nil && servercfg.CacheEnabled() {
+		storeAclInCache(acl)
+	}
+	return err
 }
 
 // UpsertAcl - upserts acl
@@ -263,12 +316,20 @@ func UpsertAcl(acl models.Acl) error {
 	if err != nil {
 		return err
 	}
-	return database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME)
+	err = database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME)
+	if err == nil && servercfg.CacheEnabled() {
+		storeAclInCache(acl)
+	}
+	return err
 }
 
 // DeleteAcl - deletes acl policy
 func DeleteAcl(a models.Acl) error {
-	return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID)
+	err := database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID)
+	if err == nil && servercfg.CacheEnabled() {
+		removeAclFromCache(a)
+	}
+	return err
 }
 
 // GetDefaultPolicy - fetches default policy in the network by ruleType
@@ -305,29 +366,45 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
 	return acl, nil
 }
 
-// ListUserPolicies - lists all acl policies enforced on an user
-func ListUserPolicies(u models.User) []models.Acl {
+func listAcls() (acls []models.Acl) {
+	if servercfg.CacheEnabled() && len(aclCacheMap) > 0 {
+		return listAclFromCache()
+	}
+
 	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 		return []models.Acl{}
 	}
-	acls := []models.Acl{}
+
 	for _, dataI := range data {
 		acl := models.Acl{}
 		err := json.Unmarshal([]byte(dataI), &acl)
 		if err != nil {
 			continue
 		}
+		acls = append(acls, acl)
+		if servercfg.CacheEnabled() {
+			storeAclInCache(acl)
+		}
+	}
+	return
+}
+
+// ListUserPolicies - lists all acl policies enforced on an user
+func ListUserPolicies(u models.User) []models.Acl {
+	allAcls := listAcls()
+	userAcls := []models.Acl{}
+	for _, acl := range allAcls {
 
 		if acl.RuleType == models.UserPolicy {
 			srcMap := convAclTagToValueMap(acl.Src)
 			if _, ok := srcMap[u.UserName]; ok {
-				acls = append(acls, acl)
+				userAcls = append(userAcls, acl)
 			} else {
 				// check for user groups
 				for gID := range u.UserGroups {
 					if _, ok := srcMap[gID.String()]; ok {
-						acls = append(acls, acl)
+						userAcls = append(userAcls, acl)
 						break
 					}
 				}
@@ -335,84 +412,61 @@ func ListUserPolicies(u models.User) []models.Acl {
 
 		}
 	}
-	return acls
+	return userAcls
 }
 
 // listPoliciesOfUser - lists all user acl policies applied to user in an network
 func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl {
-	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
-	if err != nil && !database.IsEmptyRecord(err) {
-		return []models.Acl{}
-	}
-	acls := []models.Acl{}
-	for _, dataI := range data {
-		acl := models.Acl{}
-		err := json.Unmarshal([]byte(dataI), &acl)
-		if err != nil {
-			continue
-		}
+	allAcls := listAcls()
+	userAcls := []models.Acl{}
+	for _, acl := range allAcls {
 		if acl.NetworkID == netID && acl.RuleType == models.UserPolicy {
 			srcMap := convAclTagToValueMap(acl.Src)
 			if _, ok := srcMap[user.UserName]; ok {
-				acls = append(acls, acl)
+				userAcls = append(userAcls, acl)
 				continue
 			}
 			for netRole := range user.NetworkRoles {
 				if _, ok := srcMap[netRole.String()]; ok {
-					acls = append(acls, acl)
+					userAcls = append(userAcls, acl)
 					continue
 				}
 			}
 			for userG := range user.UserGroups {
 				if _, ok := srcMap[userG.String()]; ok {
-					acls = append(acls, acl)
+					userAcls = append(userAcls, acl)
 					continue
 				}
 			}
 
 		}
 	}
-	return acls
+	return userAcls
 }
 
 // listDevicePolicies - lists all device policies in a network
 func listDevicePolicies(netID models.NetworkID) []models.Acl {
-	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
-	if err != nil && !database.IsEmptyRecord(err) {
-		return []models.Acl{}
-	}
-	acls := []models.Acl{}
-	for _, dataI := range data {
-		acl := models.Acl{}
-		err := json.Unmarshal([]byte(dataI), &acl)
-		if err != nil {
-			continue
-		}
+	allAcls := listAcls()
+	deviceAcls := []models.Acl{}
+	for _, acl := range allAcls {
 		if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
-			acls = append(acls, acl)
+			deviceAcls = append(deviceAcls, acl)
 		}
 	}
-	return acls
+	return deviceAcls
 }
 
 // ListAcls - lists all acl policies
 func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
-	data, err := database.FetchRecords(database.ACLS_TABLE_NAME)
-	if err != nil && !database.IsEmptyRecord(err) {
-		return []models.Acl{}, err
-	}
-	acls := []models.Acl{}
-	for _, dataI := range data {
-		acl := models.Acl{}
-		err := json.Unmarshal([]byte(dataI), &acl)
-		if err != nil {
-			continue
-		}
+
+	allAcls := listAcls()
+	netAcls := []models.Acl{}
+	for _, acl := range allAcls {
 		if acl.NetworkID == netID {
-			acls = append(acls, acl)
+			netAcls = append(netAcls, acl)
 		}
 	}
-	return acls, nil
+	return netAcls, nil
 }
 
 func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
@@ -458,7 +512,6 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) bool {
 
 // IsNodeAllowedToCommunicate - check node is allowed to communicate with the peer
 func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
-	return true
 	if node.IsStatic {
 		node = node.StaticNode.ConvertToStaticNode()
 	}