Explorar o código

added db caching

Tobias Cudnik %!s(int64=2) %!d(string=hai) anos
pai
achega
ddc1c87e6d
Modificáronse 5 ficheiros con 124 adicións e 2 borrados
  1. 47 1
      logic/acls/common.go
  2. 4 0
      logic/acls/nodeacls/modify.go
  3. 3 1
      logic/acls/nodeacls/retrieve.go
  4. 33 0
      logic/hosts.go
  5. 37 0
      logic/nodes.go

+ 47 - 1
logic/acls/common.go

@@ -2,8 +2,10 @@ package acls
 
 import (
 	"encoding/json"
+	"sync"
 
 	"github.com/gravitl/netmaker/database"
+	"golang.org/x/exp/slog"
 )
 
 // == type functions ==
@@ -52,6 +54,22 @@ func (aclContainer ACLContainer) RemoveACL(ID AclID) ACLContainer {
 
 // ACLContainer.ChangeAccess - changes the relationship between two nodes in memory
 func (networkACL ACLContainer) ChangeAccess(ID1, ID2 AclID, value byte) {
+	if _, ok := networkACL[ID1]; !ok {
+		slog.Error("ACL missing for ", "id", ID1)
+		return
+	}
+	if _, ok := networkACL[ID2]; !ok {
+		slog.Error("ACL missing for ", "id", ID2)
+		return
+	}
+	if _, ok := networkACL[ID1][ID2]; !ok {
+		slog.Error("ACL missing for ", "id1", ID1, "id2", ID2)
+		return
+	}
+	if _, ok := networkACL[ID2][ID1]; !ok {
+		slog.Error("ACL missing for ", "id2", ID2, "id1", ID1)
+		return
+	}
 	networkACL[ID1][ID2] = value
 	networkACL[ID2][ID1] = value
 }
@@ -73,9 +91,26 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err
 
 // == private ==
 
+var CacheACL map[ContainerID]ACLContainer
+var CacheACLMutex = sync.RWMutex{}
+
 // fetchACLContainer - fetches all current rules in given ACL container
+// TODO pointer
 func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
-	aclJson, err := fetchACLContainerJson(ContainerID(containerID))
+	CacheACLMutex.RLock()
+	if CacheACL != nil {
+		if _, ok := CacheACL[containerID]; ok {
+			defer CacheACLMutex.RUnlock()
+			return CacheACL[containerID], nil
+		}
+	} else {
+		CacheACLMutex.RUnlock()
+		CacheACLMutex.Lock()
+		CacheACL = make(map[ContainerID]ACLContainer)
+		CacheACLMutex.Unlock()
+	}
+	// TODO cache
+	aclJson, err := fetchACLContainerJson(containerID)
 	if err != nil {
 		return nil, err
 	}
@@ -83,6 +118,9 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
 	if err := json.Unmarshal([]byte(aclJson), &currentNetworkACL); err != nil {
 		return nil, err
 	}
+	CacheACLMutex.Lock()
+	CacheACL[containerID] = currentNetworkACL
+	CacheACLMutex.Unlock()
 	return currentNetworkACL, nil
 }
 
@@ -102,6 +140,10 @@ func upsertACL(containerID ContainerID, ID AclID, acl ACL) (ACL, error) {
 		return acl, err
 	}
 	currentNetACL[ID] = acl
+	// invalidate cache
+	CacheACLMutex.Lock()
+	delete(CacheACL, containerID)
+	CacheACLMutex.Unlock()
 	_, err = upsertACLContainer(containerID, currentNetACL)
 	return acl, err
 }
@@ -112,6 +154,10 @@ func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACL
 	if aclContainer == nil {
 		aclContainer = make(ACLContainer)
 	}
+	// invalidate cache
+	CacheACLMutex.Lock()
+	delete(CacheACL, containerID)
+	CacheACLMutex.Unlock()
 	return aclContainer, database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)
 }
 

+ 4 - 0
logic/acls/nodeacls/modify.go

@@ -83,5 +83,9 @@ func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error
 
 // DeleteACLContainer - removes an ACLContainer state from db
 func DeleteACLContainer(network NetworkID) error {
+	// invalidate cache
+	acls.CacheACLMutex.Lock()
+	acls.CacheACL = nil
+	acls.CacheACLMutex.Unlock()
 	return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
 }

+ 3 - 1
logic/acls/nodeacls/retrieve.go

@@ -13,7 +13,9 @@ func AreNodesAllowed(networkID NetworkID, node1, node2 NodeID) bool {
 	if err != nil {
 		return false
 	}
-	return currentNetworkACL[acls.AclID(node1)].IsAllowed(acls.AclID(node2)) && currentNetworkACL[acls.AclID(node2)].IsAllowed(acls.AclID(node1))
+	nodeID1 := acls.AclID(node1)
+	nodeID2 := acls.AclID(node2)
+	return currentNetworkACL[nodeID1].IsAllowed(nodeID2) && currentNetworkACL[nodeID2].IsAllowed(nodeID1)
 }
 
 // FetchNodeACL - fetches a specific node's ACL in a given network

+ 33 - 0
logic/hosts.go

@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"sort"
 	"strconv"
+	"sync"
 
 	"github.com/devilcove/httpclient"
 	"github.com/google/uuid"
@@ -57,8 +58,17 @@ func GetAllHostsAPI(hosts []models.Host) []models.ApiHost {
 	return apiHosts[:]
 }
 
+var CacheHosts map[string]*models.Host
+var CacheHostsMutex = sync.RWMutex{}
+
 // GetHostsMap - gets all the current hosts on machine in a map
 func GetHostsMap() (map[string]*models.Host, error) {
+	CacheHostsMutex.RLock()
+	if CacheHosts != nil {
+		defer CacheHostsMutex.RUnlock()
+		return CacheHosts, nil
+	}
+	CacheHostsMutex.RUnlock()
 	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 		return nil, err
@@ -72,12 +82,23 @@ func GetHostsMap() (map[string]*models.Host, error) {
 		}
 		currHostMap[h.ID.String()] = &h
 	}
+	CacheHostsMutex.Lock()
+	CacheHosts = currHostMap
+	CacheHostsMutex.Unlock()
 
 	return currHostMap, nil
 }
 
 // GetHost - gets a host from db given id
 func GetHost(hostid string) (*models.Host, error) {
+	CacheHostsMutex.RLock()
+	if CacheHosts != nil {
+		if _, ok := CacheHosts[hostid]; ok {
+			defer CacheHostsMutex.RUnlock()
+			return CacheHosts[hostid], nil
+		}
+	}
+	CacheHostsMutex.RUnlock()
 	record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid)
 	if err != nil {
 		return nil, err
@@ -216,6 +237,10 @@ func UpsertHost(h *models.Host) error {
 		return err
 	}
 
+	// invalidate cache
+	CacheHostsMutex.Lock()
+	CacheHosts = nil
+	CacheHostsMutex.Unlock()
 	return database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME)
 }
 
@@ -228,6 +253,10 @@ func RemoveHost(h *models.Host) error {
 		DeRegisterHostWithTurn(h.ID.String())
 	}
 
+	// invalidate cache
+	CacheHostsMutex.Lock()
+	CacheHosts = nil
+	CacheHostsMutex.Unlock()
 	return database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String())
 }
 
@@ -236,6 +265,10 @@ func RemoveHostByID(hostID string) error {
 	if servercfg.IsUsingTurn() {
 		DeRegisterHostWithTurn(hostID)
 	}
+	// invalidate cache
+	CacheHostsMutex.Lock()
+	CacheHosts = nil
+	CacheHostsMutex.Unlock()
 	return database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID)
 }
 

+ 37 - 0
logic/nodes.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"sort"
+	"sync"
 	"time"
 
 	validator "github.com/go-playground/validator/v10"
@@ -104,6 +105,10 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
 		if data, err := json.Marshal(newNode); err != nil {
 			return err
 		} else {
+			// invalidate cache
+			CacheNodesMutex.Lock()
+			CacheNodes = nil
+			CacheNodesMutex.Unlock()
 			return database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME)
 		}
 	}
@@ -157,6 +162,10 @@ func deleteNodeByID(node *models.Node) error {
 			logger.Log(0, "failed to deleted ext clients", err.Error())
 		}
 	}
+	// invalidate cache
+	CacheNodesMutex.Lock()
+	CacheNodes = nil
+	CacheNodesMutex.Unlock()
 	if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil {
 		if !database.IsEmptyRecord(err) {
 			return err
@@ -224,8 +233,17 @@ func IsFailoverPresent(network string) bool {
 	return false
 }
 
+var CacheNodes []models.Node
+var CacheNodesMutex = sync.RWMutex{}
+
 // GetAllNodes - returns all nodes in the DB
 func GetAllNodes() ([]models.Node, error) {
+	CacheNodesMutex.RLock()
+	if CacheNodes != nil {
+		defer CacheNodesMutex.RUnlock()
+		return CacheNodes, nil
+	}
+	CacheNodesMutex.RUnlock()
 	var nodes []models.Node
 
 	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
@@ -247,6 +265,10 @@ func GetAllNodes() ([]models.Node, error) {
 		nodes = append(nodes, node)
 	}
 
+	CacheNodesMutex.Lock()
+	CacheNodes = nodes
+	CacheNodesMutex.Unlock()
+
 	return nodes, nil
 }
 
@@ -366,7 +388,18 @@ func GetNodeRelay(network string, relayedNodeAddr string) (models.Node, error) {
 	return relay, errors.New(RELAY_NODE_ERR + " " + relayedNodeAddr)
 }
 
+// TODO pointer
 func GetNodeByID(uuid string) (models.Node, error) {
+	CacheNodesMutex.RLock()
+	if CacheNodes != nil {
+		for _, node := range CacheNodes {
+			if node.ID.String() == uuid {
+				defer CacheNodesMutex.RUnlock()
+				return node, nil
+			}
+		}
+	}
+	CacheNodesMutex.RUnlock()
 	var record, err = database.FetchRecord(database.NODES_TABLE_NAME, uuid)
 	if err != nil {
 		return models.Node{}, err
@@ -532,6 +565,10 @@ func createNode(node *models.Node) error {
 	if err != nil {
 		return err
 	}
+	// invalidate cache
+	CacheNodesMutex.Lock()
+	CacheNodes = nil
+	CacheNodesMutex.Unlock()
 	err = database.Insert(node.ID.String(), string(nodebytes), database.NODES_TABLE_NAME)
 	if err != nil {
 		return err