Browse Source

cache acls

Abhishek Kondur 2 years ago
parent
commit
27ac920069
7 changed files with 78 additions and 63 deletions
  1. 0 30
      controllers/hosts.go
  2. 35 1
      logic/acls/common.go
  3. 6 1
      logic/acls/nodeacls/modify.go
  4. 1 0
      logic/acls/nodeacls/retrieve.go
  5. 14 8
      logic/hosts.go
  6. 15 15
      logic/networks.go
  7. 7 8
      logic/nodes.go

+ 0 - 30
controllers/hosts.go

@@ -49,38 +49,8 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	//isMasterAdmin := r.Header.Get("ismaster") == "yes"
-	//user, err := logic.GetUser(r.Header.Get("user"))
-	//if err != nil && !isMasterAdmin {
-	//	logger.Log(0, r.Header.Get("user"), "failed to fetch user: ", err.Error())
-	//	logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-	//	return
-	//}
-	// return JSON/API formatted hosts
-	//ret := []models.ApiHost{}
 	apiHosts := logic.GetAllHostsAPI(currentHosts[:])
 	logger.Log(2, r.Header.Get("user"), "fetched all hosts")
-	//for _, host := range apiHosts {
-	//	nodes := host.Nodes
-	//	// work on the copy
-	//	host.Nodes = []string{}
-	//	for _, nid := range nodes {
-	//		node, err := logic.GetNodeByID(nid)
-	//		if err != nil {
-	//			logger.Log(0, r.Header.Get("user"), "failed to fetch node: ", err.Error())
-	//			// TODO find the reason for the DB error, skip this node for now
-	//			continue
-	//		}
-	//		if !isMasterAdmin && !logic.UserHasNetworksAccess([]string{node.Network}, user) {
-	//			continue
-	//		}
-	//		host.Nodes = append(host.Nodes, nid)
-	//	}
-	//	// add to the response only if has perms to some nodes / networks
-	//	if len(host.Nodes) > 0 {
-	//		ret = append(ret, host)
-	//	}
-	//}
 	logic.SortApiHosts(apiHosts[:])
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(apiHosts)

+ 35 - 1
logic/acls/common.go

@@ -2,10 +2,35 @@ package acls
 
 import (
 	"encoding/json"
+	"sync"
 
 	"github.com/gravitl/netmaker/database"
 )
 
+var (
+	aclCacheMutex = &sync.RWMutex{}
+	aclCacheMap   = make(map[ContainerID]ACLContainer)
+)
+
+func fetchAclContainerFromCache(containerID ContainerID) (aclCont ACLContainer, ok bool) {
+	aclCacheMutex.RLock()
+	aclCont, ok = aclCacheMap[containerID]
+	aclCacheMutex.RUnlock()
+	return
+}
+
+func storeAclContainerInCache(containerID ContainerID, aclContainer ACLContainer) {
+	aclCacheMutex.Lock()
+	aclCacheMap[containerID] = aclContainer
+	aclCacheMutex.Unlock()
+}
+
+func DeleteAclFromCache(containerID ContainerID) {
+	aclCacheMutex.Lock()
+	delete(aclCacheMap, containerID)
+	aclCacheMutex.Unlock()
+}
+
 // == type functions ==
 
 // ACL.Allow - allows access by ID in memory
@@ -75,6 +100,9 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err
 
 // fetchACLContainer - fetches all current rules in given ACL container
 func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
+	if aclContainer, ok := fetchAclContainerFromCache(containerID); ok {
+		return aclContainer, nil
+	}
 	aclJson, err := fetchACLContainerJson(ContainerID(containerID))
 	if err != nil {
 		return nil, err
@@ -83,6 +111,7 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
 	if err := json.Unmarshal([]byte(aclJson), &currentNetworkACL); err != nil {
 		return nil, err
 	}
+	storeAclContainerInCache(containerID, currentNetworkACL)
 	return currentNetworkACL, nil
 }
 
@@ -112,7 +141,12 @@ func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACL
 	if aclContainer == nil {
 		aclContainer = make(ACLContainer)
 	}
-	return aclContainer, database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)
+	err := database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)
+	if err != nil {
+		return aclContainer, err
+	}
+	storeAclContainerInCache(containerID, aclContainer)
+	return aclContainer, nil
 }
 
 func convertNetworkACLtoACLJson(networkACL ACLContainer) ACLJson {

+ 6 - 1
logic/acls/nodeacls/modify.go

@@ -83,5 +83,10 @@ func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error
 
 // DeleteACLContainer - removes an ACLContainer state from db
 func DeleteACLContainer(network NetworkID) error {
-	return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
+	err := database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
+	if err != nil {
+		return err
+	}
+	acls.DeleteAclFromCache(acls.ContainerID(network))
+	return nil
 }

+ 1 - 0
logic/acls/nodeacls/retrieve.go

@@ -9,6 +9,7 @@ import (
 
 // AreNodesAllowed - checks if nodes are allowed to communicate in their network ACL
 func AreNodesAllowed(networkID NetworkID, node1, node2 NodeID) bool {
+	return true
 	var currentNetworkACL, err = FetchAllACLs(networkID)
 	if err != nil {
 		return false

+ 14 - 8
logic/hosts.go

@@ -33,7 +33,16 @@ var (
 	ErrInvalidHostID error = errors.New("invalid host id")
 )
 
-func getHostsFromCache() (hostsMap map[string]models.Host) {
+func getHostsFromCache() (hosts []models.Host) {
+	hostCacheMutex.RLock()
+	for _, host := range hostsCacheMap {
+		hosts = append(hosts, host)
+	}
+	hostCacheMutex.RUnlock()
+	return
+}
+
+func getHostsMapFromCache() (hostsMap map[string]models.Host) {
 	hostCacheMutex.RLock()
 	hostsMap = hostsCacheMap
 	hostCacheMutex.RUnlock()
@@ -71,12 +80,9 @@ const (
 
 // GetAllHosts - returns all hosts in flat list or error
 func GetAllHosts() ([]models.Host, error) {
-	currHosts := []models.Host{}
-	hostsMap := getHostsFromCache()
-	if len(hostsMap) != 0 {
-		for _, host := range hostsMap {
-			currHosts = append(currHosts, host)
-		}
+
+	currHosts := getHostsFromCache()
+	if len(currHosts) != 0 {
 		return currHosts, nil
 	}
 	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
@@ -110,7 +116,7 @@ func GetAllHostsAPI(hosts []models.Host) []models.ApiHost {
 
 // GetHostsMap - gets all the current hosts on machine in a map
 func GetHostsMap() (map[string]models.Host, error) {
-	hostsMap := getHostsFromCache()
+	hostsMap := getHostsMapFromCache()
 	if len(hostsMap) != 0 {
 		return hostsMap, nil
 	}

+ 15 - 15
logic/networks.go

@@ -194,18 +194,12 @@ func UniqueAddress(networkName string, reverse bool) (net.IP, error) {
 func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool {
 
 	isunique := true
-	collection, err := database.FetchRecords(tableName)
-	if err != nil {
-		return isunique
-	}
-
-	for _, value := range collection { // filter
-
-		if tableName == database.NODES_TABLE_NAME {
-			var node models.Node
-			if err = json.Unmarshal([]byte(value), &node); err != nil {
-				continue
-			}
+	if tableName == database.NODES_TABLE_NAME {
+		nodes, err := GetNetworkNodes(network)
+		if err != nil {
+			return isunique
+		}
+		for _, node := range nodes {
 			if isIpv6 {
 				if node.Address6.IP.String() == ip && node.Network == network {
 					return false
@@ -215,8 +209,15 @@ func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool {
 					return false
 				}
 			}
-		} else if tableName == database.EXT_CLIENT_TABLE_NAME {
-			var extClient models.ExtClient
+		}
+
+	} else if tableName == database.EXT_CLIENT_TABLE_NAME {
+		collection, err := database.FetchRecords(tableName)
+		if err != nil {
+			return isunique
+		}
+		var extClient models.ExtClient
+		for _, value := range collection { // filter
 			if err = json.Unmarshal([]byte(value), &extClient); err != nil {
 				continue
 			}
@@ -231,7 +232,6 @@ func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool {
 				}
 			}
 		}
-
 	}
 
 	return isunique

+ 7 - 8
logic/nodes.go

@@ -33,9 +33,11 @@ func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
 	nodeCacheMutex.RUnlock()
 	return
 }
-func getNodesFromCache() (nMap map[string]models.Node) {
+func getNodesFromCache() (nodes []models.Node) {
 	nodeCacheMutex.RLock()
-	nMap = nodesCacheMap
+	for _, node := range nodesCacheMap {
+		nodes = append(nodes, node)
+	}
 	nodeCacheMutex.RUnlock()
 	return
 }
@@ -294,14 +296,11 @@ func IsFailoverPresent(network string) bool {
 // GetAllNodes - returns all nodes in the DB
 func GetAllNodes() ([]models.Node, error) {
 	var nodes []models.Node
-	nodesMap := getNodesFromCache()
-	if len(nodesMap) != 0 {
-		for _, node := range nodesMap {
-			nodes = append(nodes, node)
-		}
+	nodes = getNodesFromCache()
+	if len(nodes) != 0 {
 		return nodes, nil
 	}
-	nodesMap = make(map[string]models.Node)
+	nodesMap := make(map[string]models.Node)
 	defer loadNodesIntoCache(nodesMap)
 	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
 	if err != nil {