Browse Source

Merge pull request #2427 from gravitl/NET-390-acl-panic-fix

NET-390: acl panic fix, DB cache
Alex Feiszli 2 years ago
parent
commit
ae92499a32

+ 1 - 2
controllers/dns_test.go

@@ -51,8 +51,7 @@ func TestGetNodeDNS(t *testing.T) {
 	createNet()
 	createHost()
 	t.Run("NoNodes", func(t *testing.T) {
-		dns, err := logic.GetNodeDNS("skynet")
-		assert.EqualError(t, err, "could not find any records")
+		dns, _ := logic.GetNodeDNS("skynet")
 		assert.Equal(t, []models.DNSEntry(nil), dns)
 	})
 	t.Run("NodeExists", func(t *testing.T) {

+ 1 - 2
controllers/ext_client.go

@@ -10,7 +10,6 @@ import (
 
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/functions"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/pro"
@@ -102,7 +101,7 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) {
 	clients := []models.ExtClient{}
 	var err error
 	if len(networksSlice) > 0 && networksSlice[0] == logic.ALL_NETWORK_ACCESS {
-		clients, err = functions.GetAllExtClients()
+		clients, err = logic.GetAllExtClients()
 		if err != nil && !database.IsEmptyRecord(err) {
 			logger.Log(0, "failed to get all extclients: ", err.Error())
 			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))

+ 0 - 30
controllers/hosts.go

@@ -49,38 +49,8 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	//isMasterAdmin := r.Header.Get("ismaster") == "yes"
-	//user, err := logic.GetUser(r.Header.Get("user"))
-	//if err != nil && !isMasterAdmin {
-	//	logger.Log(0, r.Header.Get("user"), "failed to fetch user: ", err.Error())
-	//	logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-	//	return
-	//}
-	// return JSON/API formatted hosts
-	//ret := []models.ApiHost{}
 	apiHosts := logic.GetAllHostsAPI(currentHosts[:])
 	logger.Log(2, r.Header.Get("user"), "fetched all hosts")
-	//for _, host := range apiHosts {
-	//	nodes := host.Nodes
-	//	// work on the copy
-	//	host.Nodes = []string{}
-	//	for _, nid := range nodes {
-	//		node, err := logic.GetNodeByID(nid)
-	//		if err != nil {
-	//			logger.Log(0, r.Header.Get("user"), "failed to fetch node: ", err.Error())
-	//			// TODO find the reason for the DB error, skip this node for now
-	//			continue
-	//		}
-	//		if !isMasterAdmin && !logic.UserHasNetworksAccess([]string{node.Network}, user) {
-	//			continue
-	//		}
-	//		host.Nodes = append(host.Nodes, nid)
-	//	}
-	//	// add to the response only if has perms to some nodes / networks
-	//	if len(host.Nodes) > 0 {
-	//		ret = append(ret, host)
-	//	}
-	//}
 	logic.SortApiHosts(apiHosts[:])
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(apiHosts)

+ 1 - 0
controllers/node_test.go

@@ -217,6 +217,7 @@ func TestNodeACLs(t *testing.T) {
 }
 
 func deleteAllNodes() {
+	logic.ClearNodeCache()
 	database.DeleteAllRecords(database.NODES_TABLE_NAME)
 }
 

+ 58 - 1
logic/acls/common.go

@@ -2,10 +2,37 @@ package acls
 
 import (
 	"encoding/json"
+	"sync"
 
 	"github.com/gravitl/netmaker/database"
+	"golang.org/x/exp/slog"
 )
 
+var (
+	aclCacheMutex = &sync.RWMutex{}
+	aclCacheMap   = make(map[ContainerID]ACLContainer)
+	aclMutex      = &sync.RWMutex{}
+)
+
+func fetchAclContainerFromCache(containerID ContainerID) (aclCont ACLContainer, ok bool) {
+	aclCacheMutex.RLock()
+	aclCont, ok = aclCacheMap[containerID]
+	aclCacheMutex.RUnlock()
+	return
+}
+
+func storeAclContainerInCache(containerID ContainerID, aclContainer ACLContainer) {
+	aclCacheMutex.Lock()
+	aclCacheMap[containerID] = aclContainer
+	aclCacheMutex.Unlock()
+}
+
+func DeleteAclFromCache(containerID ContainerID) {
+	aclCacheMutex.Lock()
+	delete(aclCacheMap, containerID)
+	aclCacheMutex.Unlock()
+}
+
 // == type functions ==
 
 // ACL.Allow - allows access by ID in memory
@@ -52,6 +79,22 @@ func (aclContainer ACLContainer) RemoveACL(ID AclID) ACLContainer {
 
 // ACLContainer.ChangeAccess - changes the relationship between two nodes in memory
 func (networkACL ACLContainer) ChangeAccess(ID1, ID2 AclID, value byte) {
+	if _, ok := networkACL[ID1]; !ok {
+		slog.Error("ACL missing for ", "id", ID1)
+		return
+	}
+	if _, ok := networkACL[ID2]; !ok {
+		slog.Error("ACL missing for ", "id", ID2)
+		return
+	}
+	if _, ok := networkACL[ID1][ID2]; !ok {
+		slog.Error("ACL missing for ", "id1", ID1, "id2", ID2)
+		return
+	}
+	if _, ok := networkACL[ID2][ID1]; !ok {
+		slog.Error("ACL missing for ", "id2", ID2, "id1", ID1)
+		return
+	}
 	networkACL[ID1][ID2] = value
 	networkACL[ID2][ID1] = value
 }
@@ -75,6 +118,11 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err
 
 // fetchACLContainer - fetches all current rules in given ACL container
 func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
+	aclMutex.RLock()
+	defer aclMutex.RUnlock()
+	if aclContainer, ok := fetchAclContainerFromCache(containerID); ok {
+		return aclContainer, nil
+	}
 	aclJson, err := fetchACLContainerJson(ContainerID(containerID))
 	if err != nil {
 		return nil, err
@@ -83,6 +131,7 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
 	if err := json.Unmarshal([]byte(aclJson), &currentNetworkACL); err != nil {
 		return nil, err
 	}
+	storeAclContainerInCache(containerID, currentNetworkACL)
 	return currentNetworkACL, nil
 }
 
@@ -109,10 +158,18 @@ func upsertACL(containerID ContainerID, ID AclID, acl ACL) (ACL, error) {
 // upsertACLContainer - Inserts or updates a network ACL given the json string of the ACL and the container ID
 // if nil, create it
 func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACLContainer, error) {
+	aclMutex.Lock()
+	defer aclMutex.Unlock()
 	if aclContainer == nil {
 		aclContainer = make(ACLContainer)
 	}
-	return aclContainer, database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)
+
+	err := database.Insert(string(containerID), string(convertNetworkACLtoACLJson(aclContainer)), database.NODE_ACLS_TABLE_NAME)
+	if err != nil {
+		return aclContainer, err
+	}
+	storeAclContainerInCache(containerID, aclContainer)
+	return aclContainer, nil
 }
 
 func convertNetworkACLtoACLJson(networkACL ACLContainer) ACLJson {

+ 6 - 1
logic/acls/nodeacls/modify.go

@@ -83,5 +83,10 @@ func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error
 
 // DeleteACLContainer - removes an ACLContainer state from db
 func DeleteACLContainer(network NetworkID) error {
-	return database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
+	err := database.DeleteRecord(database.NODE_ACLS_TABLE_NAME, string(network))
+	if err != nil {
+		return err
+	}
+	acls.DeleteAclFromCache(acls.ContainerID(network))
+	return nil
 }

+ 2 - 6
logic/dns.go

@@ -69,16 +69,12 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) {
 
 	var dns []models.DNSEntry
 
-	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
+	nodes, err := GetNetworkNodes(network)
 	if err != nil {
 		return dns, err
 	}
 
-	for _, value := range collection {
-		var node models.Node
-		if err = json.Unmarshal([]byte(value), &node); err != nil {
-			continue
-		}
+	for _, node := range nodes {
 		if node.Network != network {
 			continue
 		}

+ 53 - 35
logic/extpeers.go

@@ -3,58 +3,56 @@ package logic
 import (
 	"encoding/json"
 	"fmt"
+	"sync"
 	"time"
 
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
-// GetExtPeersList - gets the ext peers lists
-func GetExtPeersList(node *models.Node) ([]models.ExtPeersResponse, error) {
-
-	var peers []models.ExtPeersResponse
-	records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
+var (
+	extClientCacheMutex = &sync.RWMutex{}
+	extClientCacheMap   = make(map[string]models.ExtClient)
+)
 
-	if err != nil {
-		return peers, err
+func getAllExtClientsFromCache() (extClients []models.ExtClient) {
+	extClientCacheMutex.RLock()
+	for _, extclient := range extClientCacheMap {
+		extClients = append(extClients, extclient)
 	}
+	extClientCacheMutex.RUnlock()
+	return
+}
 
-	for _, value := range records {
-		var peer models.ExtPeersResponse
-		var extClient models.ExtClient
-		err = json.Unmarshal([]byte(value), &peer)
-		if err != nil {
-			logger.Log(2, "failed to unmarshal peer when getting ext peer list")
-			continue
-		}
-		err = json.Unmarshal([]byte(value), &extClient)
-		if err != nil {
-			logger.Log(2, "failed to unmarshal ext client")
-			continue
-		}
+func deleteExtClientFromCache(key string) {
+	extClientCacheMutex.Lock()
+	delete(extClientCacheMap, key)
+	extClientCacheMutex.Unlock()
+}
 
-		if extClient.Enabled && extClient.Network == node.Network && extClient.IngressGatewayID == node.ID.String() {
-			peers = append(peers, peer)
-		}
-	}
-	return peers, err
+func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) {
+	extClientCacheMutex.RLock()
+	extclient, ok = extClientCacheMap[key]
+	extClientCacheMutex.RUnlock()
+	return
+}
+
+func storeExtClientInCache(key string, extclient models.ExtClient) {
+	extClientCacheMutex.Lock()
+	extClientCacheMap[key] = extclient
+	extClientCacheMutex.Unlock()
 }
 
 // ExtClient.GetEgressRangesOnNetwork - returns the egress ranges on network of ext client
 func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
 
 	var result []string
-	nodesData, err := database.FetchRecords(database.NODES_TABLE_NAME)
+	networkNodes, err := GetNetworkNodes(client.Network)
 	if err != nil {
 		return []string{}, err
 	}
-	for _, nodeData := range nodesData {
-		var currentNode models.Node
-		if err = json.Unmarshal([]byte(nodeData), &currentNode); err != nil {
-			continue
-		}
+	for _, currentNode := range networkNodes {
 		if currentNode.Network != client.Network {
 			continue
 		}
@@ -75,13 +73,25 @@ func DeleteExtClient(network string, clientid string) error {
 		return err
 	}
 	err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key)
-	return err
+	if err != nil {
+		return err
+	}
+	deleteExtClientFromCache(key)
+	return nil
 }
 
 // GetNetworkExtClients - gets the ext clients of given network
 func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
 	var extclients []models.ExtClient
-
+	allextclients := getAllExtClientsFromCache()
+	if len(allextclients) != 0 {
+		for _, extclient := range allextclients {
+			if extclient.Network == network {
+				extclients = append(extclients, extclient)
+			}
+		}
+		return extclients, nil
+	}
 	records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
 	if err != nil {
 		return extclients, err
@@ -92,6 +102,10 @@ func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
 		if err != nil {
 			continue
 		}
+		key, err := GetRecordKey(extclient.ClientID, network)
+		if err == nil {
+			storeExtClientInCache(key, extclient)
+		}
 		if extclient.Network == network {
 			extclients = append(extclients, extclient)
 		}
@@ -106,12 +120,15 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) {
 	if err != nil {
 		return extclient, err
 	}
+	if extclient, ok := getExtClientFromCache(key); ok {
+		return extclient, nil
+	}
 	data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
 	if err != nil {
 		return extclient, err
 	}
 	err = json.Unmarshal([]byte(data), &extclient)
-
+	storeExtClientInCache(key, extclient)
 	return extclient, err
 }
 
@@ -190,6 +207,7 @@ func SaveExtClient(extclient *models.ExtClient) error {
 	if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
 		return err
 	}
+	storeExtClientInCache(key, *extclient)
 	return SetNetworkNodesLastModified(extclient.Network)
 }
 

+ 4 - 23
logic/gateway.go

@@ -1,7 +1,6 @@
 package logic
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"time"
@@ -53,11 +52,7 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
 	node.EgressGatewayNatEnabled = models.ParseBool(gateway.NatEnabled)
 	node.EgressGatewayRequest = gateway // store entire request for use when preserving the egress gateway
 	node.SetLastModified()
-	nodeData, err := json.Marshal(&node)
-	if err != nil {
-		return node, err
-	}
-	if err = database.Insert(node.ID.String(), string(nodeData), database.NODES_TABLE_NAME); err != nil {
+	if err = UpsertNode(&node); err != nil {
 		return models.Node{}, err
 	}
 	return node, nil
@@ -84,12 +79,7 @@ func DeleteEgressGateway(network, nodeid string) (models.Node, error) {
 	node.EgressGatewayRanges = []string{}
 	node.EgressGatewayRequest = models.EgressGatewayRequest{} // remove preserved request as the egress gateway is gone
 	node.SetLastModified()
-
-	data, err := json.Marshal(&node)
-	if err != nil {
-		return models.Node{}, err
-	}
-	if err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil {
+	if err = UpsertNode(&node); err != nil {
 		return models.Node{}, err
 	}
 	return node, nil
@@ -128,11 +118,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
 	if ingress.Failover && servercfg.Is_EE {
 		node.Failover = true
 	}
-	data, err := json.Marshal(&node)
-	if err != nil {
-		return models.Node{}, err
-	}
-	err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
+	err = UpsertNode(&node)
 	if err != nil {
 		return models.Node{}, err
 	}
@@ -173,12 +159,7 @@ func DeleteIngressGateway(networkName string, nodeid string) (models.Node, bool,
 				node.EgressGatewayRequest.NodeID, node.EgressGatewayRequest.NetID, err))
 		}
 	}
-
-	data, err := json.Marshal(&node)
-	if err != nil {
-		return models.Node{}, false, removedClients, err
-	}
-	err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
+	err = UpsertNode(&node)
 	if err != nil {
 		return models.Node{}, wasFailover, removedClients, err
 	}

+ 96 - 16
logic/hosts.go

@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"sort"
 	"strconv"
+	"sync"
 
 	"github.com/devilcove/httpclient"
 	"github.com/google/uuid"
@@ -20,6 +21,11 @@ import (
 	"golang.org/x/crypto/bcrypt"
 )
 
+var (
+	hostCacheMutex = &sync.RWMutex{}
+	hostsCacheMap  = make(map[string]models.Host)
+)
+
 var (
 	// ErrHostExists error indicating that host exists when trying to create new host
 	ErrHostExists error = errors.New("host already exists")
@@ -27,6 +33,46 @@ var (
 	ErrInvalidHostID error = errors.New("invalid host id")
 )
 
+func getHostsFromCache() (hosts []models.Host) {
+	hostCacheMutex.RLock()
+	for _, host := range hostsCacheMap {
+		hosts = append(hosts, host)
+	}
+	hostCacheMutex.RUnlock()
+	return
+}
+
+func getHostsMapFromCache() (hostsMap map[string]models.Host) {
+	hostCacheMutex.RLock()
+	hostsMap = hostsCacheMap
+	hostCacheMutex.RUnlock()
+	return
+}
+
+func getHostFromCache(hostID string) (host models.Host, ok bool) {
+	hostCacheMutex.RLock()
+	host, ok = hostsCacheMap[hostID]
+	hostCacheMutex.RUnlock()
+	return
+}
+
+func storeHostInCache(h models.Host) {
+	hostCacheMutex.Lock()
+	hostsCacheMap[h.ID.String()] = h
+	hostCacheMutex.Unlock()
+}
+
+func deleteHostFromCache(hostID string) {
+	hostCacheMutex.Lock()
+	delete(hostsCacheMap, hostID)
+	hostCacheMutex.Unlock()
+}
+func loadHostsIntoCache(hMap map[string]models.Host) {
+	hostCacheMutex.Lock()
+	hostsCacheMap = hMap
+	hostCacheMutex.Unlock()
+}
+
 const (
 	maxPort = 1<<16 - 1
 	minPort = 1025
@@ -34,17 +80,28 @@ const (
 
 // GetAllHosts - returns all hosts in flat list or error
 func GetAllHosts() ([]models.Host, error) {
-	currHostMap, err := GetHostsMap()
-	if err != nil {
+
+	currHosts := getHostsFromCache()
+	if len(currHosts) != 0 {
+		return currHosts, nil
+	}
+	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
+	if err != nil && !database.IsEmptyRecord(err) {
 		return nil, err
 	}
-	var currentHosts = []models.Host{}
-	for k := range currHostMap {
-		var h = *currHostMap[k]
-		currentHosts = append(currentHosts, h)
+	currHostsMap := make(map[string]models.Host)
+	defer loadHostsIntoCache(currHostsMap)
+	for k := range records {
+		var h models.Host
+		err = json.Unmarshal([]byte(records[k]), &h)
+		if err != nil {
+			return nil, err
+		}
+		currHosts = append(currHosts, h)
+		currHostsMap[h.ID.String()] = h
 	}
 
-	return currentHosts, nil
+	return currHosts, nil
 }
 
 // GetAllHostsAPI - get's all the hosts in an API usable format
@@ -58,19 +115,24 @@ func GetAllHostsAPI(hosts []models.Host) []models.ApiHost {
 }
 
 // GetHostsMap - gets all the current hosts on machine in a map
-func GetHostsMap() (map[string]*models.Host, error) {
+func GetHostsMap() (map[string]models.Host, error) {
+	hostsMap := getHostsMapFromCache()
+	if len(hostsMap) != 0 {
+		return hostsMap, nil
+	}
 	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 		return nil, err
 	}
-	currHostMap := make(map[string]*models.Host)
+	currHostMap := make(map[string]models.Host)
+	defer loadHostsIntoCache(currHostMap)
 	for k := range records {
 		var h models.Host
 		err = json.Unmarshal([]byte(records[k]), &h)
 		if err != nil {
 			return nil, err
 		}
-		currHostMap[h.ID.String()] = &h
+		currHostMap[h.ID.String()] = h
 	}
 
 	return currHostMap, nil
@@ -78,6 +140,10 @@ func GetHostsMap() (map[string]*models.Host, error) {
 
 // GetHost - gets a host from db given id
 func GetHost(hostid string) (*models.Host, error) {
+
+	if host, ok := getHostFromCache(hostid); ok {
+		return &host, nil
+	}
 	record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid)
 	if err != nil {
 		return nil, err
@@ -87,7 +153,7 @@ func GetHost(hostid string) (*models.Host, error) {
 	if err = json.Unmarshal([]byte(record), &h); err != nil {
 		return nil, err
 	}
-
+	storeHostInCache(h)
 	return &h, nil
 }
 
@@ -221,8 +287,12 @@ func UpsertHost(h *models.Host) error {
 	if err != nil {
 		return err
 	}
-
-	return database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME)
+	err = database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME)
+	if err != nil {
+		return err
+	}
+	storeHostInCache(*h)
+	return nil
 }
 
 // RemoveHost - removes a given host from server
@@ -233,8 +303,12 @@ func RemoveHost(h *models.Host) error {
 	if servercfg.IsUsingTurn() {
 		DeRegisterHostWithTurn(h.ID.String())
 	}
-
-	return database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String())
+	err := database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String())
+	if err != nil {
+		return err
+	}
+	deleteHostFromCache(h.ID.String())
+	return nil
 }
 
 // RemoveHostByID - removes a given host by id from server
@@ -242,7 +316,13 @@ func RemoveHostByID(hostID string) error {
 	if servercfg.IsUsingTurn() {
 		DeRegisterHostWithTurn(hostID)
 	}
-	return database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID)
+
+	err := database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID)
+	if err != nil {
+		return err
+	}
+	deleteHostFromCache(hostID)
+	return nil
 }
 
 // UpdateHostNetwork - adds/deletes host from a network

+ 17 - 179
logic/networks.go

@@ -115,24 +115,8 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 
 // GetNetworkNonServerNodeCount - get number of network non server nodes
 func GetNetworkNonServerNodeCount(networkName string) (int, error) {
-
-	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
-	count := 0
-	if err != nil && !database.IsEmptyRecord(err) {
-		return count, err
-	}
-	for _, value := range collection {
-		var node models.Node
-		if err = json.Unmarshal([]byte(value), &node); err != nil {
-			return count, err
-		} else {
-			if node.Network == networkName {
-				count++
-			}
-		}
-	}
-
-	return count, nil
+	nodes, err := GetNetworkNodes(networkName)
+	return len(nodes), err
 }
 
 // GetParentNetwork - get parent network
@@ -210,18 +194,12 @@ func UniqueAddress(networkName string, reverse bool) (net.IP, error) {
 func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool {
 
 	isunique := true
-	collection, err := database.FetchRecords(tableName)
-	if err != nil {
-		return isunique
-	}
-
-	for _, value := range collection { // filter
-
-		if tableName == database.NODES_TABLE_NAME {
-			var node models.Node
-			if err = json.Unmarshal([]byte(value), &node); err != nil {
-				continue
-			}
+	if tableName == database.NODES_TABLE_NAME {
+		nodes, err := GetNetworkNodes(network)
+		if err != nil {
+			return isunique
+		}
+		for _, node := range nodes {
 			if isIpv6 {
 				if node.Address6.IP.String() == ip && node.Network == network {
 					return false
@@ -231,11 +209,15 @@ func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool {
 					return false
 				}
 			}
-		} else if tableName == database.EXT_CLIENT_TABLE_NAME {
-			var extClient models.ExtClient
-			if err = json.Unmarshal([]byte(value), &extClient); err != nil {
-				continue
-			}
+		}
+
+	} else if tableName == database.EXT_CLIENT_TABLE_NAME {
+
+		extClients, err := GetNetworkExtClients(network)
+		if err != nil {
+			return isunique
+		}
+		for _, extClient := range extClients { // filter
 			if isIpv6 {
 				if (extClient.Address6 == ip) && extClient.Network == network {
 					return false
@@ -247,7 +229,6 @@ func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool {
 				}
 			}
 		}
-
 	}
 
 	return isunique
@@ -298,149 +279,6 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) {
 	return add, errors.New("ERROR: No unique IPv6 addresses available. Check network subnet")
 }
 
-// UpdateNetworkLocalAddresses - updates network localaddresses
-func UpdateNetworkLocalAddresses(networkName string) error {
-
-	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
-
-	if err != nil {
-		return err
-	}
-
-	for _, value := range collection {
-
-		var node models.Node
-
-		err := json.Unmarshal([]byte(value), &node)
-		if err != nil {
-			fmt.Println("error in node address assignment!")
-			return err
-		}
-		if node.Network == networkName {
-			var ipaddr net.IP
-			var iperr error
-			ipaddr, iperr = UniqueAddress(networkName, false)
-			if iperr != nil {
-				fmt.Println("error in node  address assignment!")
-				return iperr
-			}
-
-			node.Address.IP = ipaddr
-			newNodeData, err := json.Marshal(&node)
-			if err != nil {
-				logger.Log(1, "error in node  address assignment!")
-				return err
-			}
-			database.Insert(node.ID.String(), string(newNodeData), database.NODES_TABLE_NAME)
-		}
-	}
-
-	return nil
-}
-
-// RemoveNetworkNodeIPv6Addresses - removes network node IPv6 addresses
-func RemoveNetworkNodeIPv6Addresses(networkName string) error {
-
-	collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
-	if err != nil {
-		return err
-	}
-
-	for _, value := range collections {
-
-		var node models.Node
-		err := json.Unmarshal([]byte(value), &node)
-		if err != nil {
-			fmt.Println("error in node address assignment!")
-			return err
-		}
-		if node.Network == networkName {
-			node.Address6.IP = nil
-			data, err := json.Marshal(&node)
-			if err != nil {
-				return err
-			}
-			database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
-		}
-	}
-
-	return nil
-}
-
-// UpdateNetworkNodeAddresses - updates network node addresses
-func UpdateNetworkNodeAddresses(networkName string) error {
-
-	collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
-	if err != nil {
-		return err
-	}
-
-	for _, value := range collections {
-
-		var node models.Node
-		err := json.Unmarshal([]byte(value), &node)
-		if err != nil {
-			logger.Log(1, "error in node ipv4 address assignment!")
-			return err
-		}
-		if node.Network == networkName {
-			var ipaddr net.IP
-			var iperr error
-			ipaddr, iperr = UniqueAddress(networkName, false)
-			if iperr != nil {
-				logger.Log(1, "error in node ipv4 address assignment!")
-				return iperr
-			}
-
-			node.Address.IP = ipaddr
-			data, err := json.Marshal(&node)
-			if err != nil {
-				return err
-			}
-			database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
-		}
-	}
-
-	return nil
-}
-
-// UpdateNetworkNodeAddresses6 - updates network node addresses
-func UpdateNetworkNodeAddresses6(networkName string) error {
-
-	collections, err := database.FetchRecords(database.NODES_TABLE_NAME)
-	if err != nil {
-		return err
-	}
-
-	for _, value := range collections {
-
-		var node models.Node
-		err := json.Unmarshal([]byte(value), &node)
-		if err != nil {
-			logger.Log(1, "error in node ipv6 address assignment!")
-			return err
-		}
-		if node.Network == networkName {
-			var ipaddr net.IP
-			var iperr error
-			ipaddr, iperr = UniqueAddress6(networkName, false)
-			if iperr != nil {
-				logger.Log(1, "error in node ipv6 address assignment!")
-				return iperr
-			}
-
-			node.Address6.IP = ipaddr
-			data, err := json.Marshal(&node)
-			if err != nil {
-				return err
-			}
-			database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
-		}
-	}
-
-	return nil
-}
-
 // IsNetworkNameUnique - checks to see if any other networks have the same name (id)
 func IsNetworkNameUnique(network *models.Network) (bool, error) {
 

+ 75 - 45
logic/nodes.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"sort"
+	"sync"
 	"time"
 
 	validator "github.com/go-playground/validator/v10"
@@ -17,11 +18,53 @@ import (
 	"github.com/gravitl/netmaker/logic/pro"
 	"github.com/gravitl/netmaker/logic/pro/proacls"
 	"github.com/gravitl/netmaker/models"
-	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/validation"
 )
 
+var (
+	nodeCacheMutex = &sync.RWMutex{}
+	nodesCacheMap  = make(map[string]models.Node)
+)
+
+func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
+	nodeCacheMutex.RLock()
+	node, ok = nodesCacheMap[nodeID]
+	nodeCacheMutex.RUnlock()
+	return
+}
+func getNodesFromCache() (nodes []models.Node) {
+	nodeCacheMutex.RLock()
+	for _, node := range nodesCacheMap {
+		nodes = append(nodes, node)
+	}
+	nodeCacheMutex.RUnlock()
+	return
+}
+
+func deleteNodeFromCache(nodeID string) {
+	nodeCacheMutex.Lock()
+	delete(nodesCacheMap, nodeID)
+	nodeCacheMutex.Unlock()
+}
+
+func storeNodeInCache(node models.Node) {
+	nodeCacheMutex.Lock()
+	nodesCacheMap[node.ID.String()] = node
+	nodeCacheMutex.Unlock()
+}
+
+func loadNodesIntoCache(nMap map[string]models.Node) {
+	nodeCacheMutex.Lock()
+	nodesCacheMap = nMap
+	nodeCacheMutex.Unlock()
+}
+func ClearNodeCache() {
+	nodeCacheMutex.Lock()
+	nodesCacheMap = make(map[string]models.Node)
+	nodeCacheMutex.Unlock()
+}
+
 const (
 	// RELAY_NODE_ERR - error to return if relay node is unfound
 	RELAY_NODE_ERR = "could not find relay for node"
@@ -72,7 +115,12 @@ func UpdateNodeCheckin(node *models.Node) error {
 	if err != nil {
 		return err
 	}
-	return database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
+	err = database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME)
+	if err != nil {
+		return err
+	}
+	storeNodeInCache(*node)
+	return nil
 }
 
 // UpsertNode - updates node in the DB
@@ -82,7 +130,12 @@ func UpsertNode(newNode *models.Node) error {
 	if err != nil {
 		return err
 	}
-	return database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME)
+	err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME)
+	if err != nil {
+		return err
+	}
+	storeNodeInCache(*newNode)
+	return nil
 }
 
 // UpdateNode - takes a node and updates another node with it's values
@@ -114,7 +167,12 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
 		if data, err := json.Marshal(newNode); err != nil {
 			return err
 		} else {
-			return database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME)
+			err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME)
+			if err != nil {
+				return err
+			}
+			storeNodeInCache(*newNode)
+			return nil
 		}
 	}
 
@@ -172,6 +230,7 @@ func deleteNodeByID(node *models.Node) error {
 			return err
 		}
 	}
+	deleteNodeFromCache(node.ID.String())
 	if servercfg.IsDNSMode() {
 		SetDNS()
 	}
@@ -237,7 +296,12 @@ func IsFailoverPresent(network string) bool {
 // GetAllNodes - returns all nodes in the DB
 func GetAllNodes() ([]models.Node, error) {
 	var nodes []models.Node
-
+	nodes = getNodesFromCache()
+	if len(nodes) != 0 {
+		return nodes, nil
+	}
+	nodesMap := make(map[string]models.Node)
+	defer loadNodesIntoCache(nodesMap)
 	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
 	if err != nil {
 		if database.IsEmptyRecord(err) {
@@ -255,6 +319,7 @@ func GetAllNodes() ([]models.Node, error) {
 		}
 		// add node to our array
 		nodes = append(nodes, node)
+		nodesMap[node.ID.String()] = node
 	}
 
 	return nodes, nil
@@ -309,46 +374,10 @@ func GetRecordKey(id string, network string) (string, error) {
 	return id + "###" + network, nil
 }
 
-// GetNodesByAddress - gets a node by mac address
-func GetNodesByAddress(network string, addresses []string) ([]models.Node, error) {
-	var nodes []models.Node
-	allnodes, err := GetAllNodes()
-	if err != nil {
-		return []models.Node{}, err
-	}
-	for _, node := range allnodes {
-		if node.Network == network && ncutils.StringSliceContains(addresses, node.Address.String()) {
-			nodes = append(nodes, node)
-		}
-	}
-	return nodes, nil
-}
-
-// GetDeletedNodeByMacAddress - get a deleted node
-func GetDeletedNodeByMacAddress(network string, macaddress string) (models.Node, error) {
-
-	var node models.Node
-
-	key, err := GetRecordKey(macaddress, network)
-	if err != nil {
-		return node, err
-	}
-
-	record, err := database.FetchRecord(database.DELETED_NODES_TABLE_NAME, key)
-	if err != nil {
-		return models.Node{}, err
-	}
-
-	if err = json.Unmarshal([]byte(record), &node); err != nil {
-		return models.Node{}, err
-	}
-
-	SetNodeDefaults(&node)
-
-	return node, nil
-}
-
 func GetNodeByID(uuid string) (models.Node, error) {
+	if node, ok := getNodeFromCache(uuid); ok {
+		return node, nil
+	}
 	var record, err = database.FetchRecord(database.NODES_TABLE_NAME, uuid)
 	if err != nil {
 		return models.Node{}, err
@@ -357,6 +386,7 @@ func GetNodeByID(uuid string) (models.Node, error) {
 	if err = json.Unmarshal([]byte(record), &node); err != nil {
 		return models.Node{}, err
 	}
+	storeNodeInCache(node)
 	return node, nil
 }
 
@@ -506,7 +536,7 @@ func createNode(node *models.Node) error {
 	if err != nil {
 		return err
 	}
-
+	storeNodeInCache(*node)
 	_, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal)
 	if err != nil {
 		logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error())

+ 3 - 28
logic/relay.go

@@ -1,12 +1,10 @@
 package logic
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"net"
 
-	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 )
@@ -33,25 +31,11 @@ func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error)
 	node.IsRelay = true
 	node.RelayedNodes = relay.RelayedNodes
 	node.SetLastModified()
-	nodeData, err := json.Marshal(&node)
+	err = UpsertNode(&node)
 	if err != nil {
 		return returnnodes, node, err
 	}
-	if err = database.Insert(node.ID.String(), string(nodeData), database.NODES_TABLE_NAME); err != nil {
-		return returnnodes, models.Node{}, err
-	}
 	returnnodes = SetRelayedNodes(true, relay.NodeID, relay.RelayedNodes)
-	for _, relayedNode := range returnnodes {
-		data, err := json.Marshal(&relayedNode)
-		if err != nil {
-			logger.Log(0, "marshalling relayed node", err.Error())
-			continue
-		}
-		if err := database.Insert(relayedNode.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil {
-			logger.Log(0, "inserting relayed node", err.Error())
-			continue
-		}
-	}
 	return returnnodes, node, nil
 }
 
@@ -71,12 +55,7 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N
 			node.RelayedBy = ""
 		}
 		node.SetLastModified()
-		data, err := json.Marshal(&node)
-		if err != nil {
-			logger.Log(0, "setRelayedNodes.Marshal", err.Error())
-			continue
-		}
-		if err := database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil {
+		if err := UpsertNode(&node); err != nil {
 			logger.Log(0, "setRelayedNodes.Insert", err.Error())
 			continue
 		}
@@ -145,11 +124,7 @@ func DeleteRelay(network, nodeid string) ([]models.Node, models.Node, error) {
 	node.IsRelay = false
 	node.RelayedNodes = []string{}
 	node.SetLastModified()
-	data, err := json.Marshal(&node)
-	if err != nil {
-		return returnnodes, models.Node{}, err
-	}
-	if err = database.Insert(nodeid, string(data), database.NODES_TABLE_NAME); err != nil {
+	if err = UpsertNode(&node); err != nil {
 		return returnnodes, models.Node{}, err
 	}
 	return returnnodes, node, nil

+ 1 - 1
mq/mq.go

@@ -80,7 +80,7 @@ func SetupMQTT() {
 			logger.Log(0, "node metrics subscription failed")
 		}
 
-		opts.SetOrderMatters(true)
+		opts.SetOrderMatters(false)
 		opts.SetResumeSubs(true)
 	})
 	mqclient = mqtt.NewClient(opts)