Browse Source

NET-735: HA Support (#2701)

* cache enabled option, cache hosts data if only enabled

* cache nodes only when enabled

* cache extclients only when enabled

* cache acls only when enabled
Abhishek K 1 year ago
parent
commit
2c4a27c53b

+ 1 - 0
config/config.go

@@ -91,6 +91,7 @@ type ServerConfig struct {
 	Environment                string        `yaml:"environment"`
 	JwtValidityDuration        time.Duration `yaml:"jwt_validity_duration"`
 	RacAutoDisable             bool          `yaml:"rac_auto_disable"`
+	CacheEnabled               bool          `yaml:"caching_enabled"`
 }
 
 // SQLConfig - Generic SQL Config

+ 4 - 1
controllers/node_test.go

@@ -10,6 +10,7 @@ import (
 	"github.com/gravitl/netmaker/logic/acls"
 	"github.com/gravitl/netmaker/logic/acls/nodeacls"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
 	"github.com/stretchr/testify/assert"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
@@ -217,7 +218,9 @@ func TestNodeACLs(t *testing.T) {
 }
 
 func deleteAllNodes() {
-	logic.ClearNodeCache()
+	if servercfg.CacheEnabled() {
+		logic.ClearNodeCache()
+	}
 	database.DeleteAllRecords(database.NODES_TABLE_NAME)
 }
 

+ 11 - 4
logic/acls/common.go

@@ -5,6 +5,7 @@ import (
 	"sync"
 
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/exp/slog"
 )
 
@@ -128,8 +129,10 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err
 func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
 	aclMutex.RLock()
 	defer aclMutex.RUnlock()
-	if aclContainer, ok := fetchAclContainerFromCache(containerID); ok {
-		return aclContainer, nil
+	if servercfg.CacheEnabled() {
+		if aclContainer, ok := fetchAclContainerFromCache(containerID); ok {
+			return aclContainer, nil
+		}
 	}
 	aclJson, err := fetchACLContainerJson(ContainerID(containerID))
 	if err != nil {
@@ -139,7 +142,9 @@ func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
 	if err := json.Unmarshal([]byte(aclJson), &currentNetworkACL); err != nil {
 		return nil, err
 	}
-	storeAclContainerInCache(containerID, currentNetworkACL)
+	if servercfg.CacheEnabled() {
+		storeAclContainerInCache(containerID, currentNetworkACL)
+	}
 	return currentNetworkACL, nil
 }
 
@@ -176,7 +181,9 @@ func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACL
 	if err != nil {
 		return aclContainer, err
 	}
-	storeAclContainerInCache(containerID, aclContainer)
+	if servercfg.CacheEnabled() {
+		storeAclContainerInCache(containerID, aclContainer)
+	}
 	return aclContainer, nil
 }
 

+ 4 - 1
logic/acls/nodeacls/modify.go

@@ -3,6 +3,7 @@ package nodeacls
 import (
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logic/acls"
+	"github.com/gravitl/netmaker/servercfg"
 )
 
 // CreateNodeACL - inserts or updates a node ACL on given network and adds to state
@@ -87,6 +88,8 @@ func DeleteACLContainer(network NetworkID) error {
 	if err != nil {
 		return err
 	}
-	acls.DeleteAclFromCache(acls.ContainerID(network))
+	if servercfg.CacheEnabled() {
+		acls.DeleteAclFromCache(acls.ContainerID(network))
+	}
 	return nil
 }

+ 25 - 12
logic/extpeers.go

@@ -11,6 +11,7 @@ import (
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/exp/slog"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
@@ -80,21 +81,25 @@ func DeleteExtClient(network string, clientid string) error {
 	if err != nil {
 		return err
 	}
-	deleteExtClientFromCache(key)
+	if servercfg.CacheEnabled() {
+		deleteExtClientFromCache(key)
+	}
 	return nil
 }
 
 // GetNetworkExtClients - gets the ext clients of given network
 func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
 	var extclients []models.ExtClient
-	allextclients := getAllExtClientsFromCache()
-	if len(allextclients) != 0 {
-		for _, extclient := range allextclients {
-			if extclient.Network == network {
-				extclients = append(extclients, extclient)
+	if servercfg.CacheEnabled() {
+		allextclients := getAllExtClientsFromCache()
+		if len(allextclients) != 0 {
+			for _, extclient := range allextclients {
+				if extclient.Network == network {
+					extclients = append(extclients, extclient)
+				}
 			}
+			return extclients, nil
 		}
-		return extclients, nil
 	}
 	records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME)
 	if err != nil {
@@ -111,7 +116,9 @@ func GetNetworkExtClients(network string) ([]models.ExtClient, error) {
 		}
 		key, err := GetRecordKey(extclient.ClientID, extclient.Network)
 		if err == nil {
-			storeExtClientInCache(key, extclient)
+			if servercfg.CacheEnabled() {
+				storeExtClientInCache(key, extclient)
+			}
 		}
 		if extclient.Network == network {
 			extclients = append(extclients, extclient)
@@ -127,15 +134,19 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) {
 	if err != nil {
 		return extclient, err
 	}
-	if extclient, ok := getExtClientFromCache(key); ok {
-		return extclient, nil
+	if servercfg.CacheEnabled() {
+		if extclient, ok := getExtClientFromCache(key); ok {
+			return extclient, nil
+		}
 	}
 	data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
 	if err != nil {
 		return extclient, err
 	}
 	err = json.Unmarshal([]byte(data), &extclient)
-	storeExtClientInCache(key, extclient)
+	if servercfg.CacheEnabled() {
+		storeExtClientInCache(key, extclient)
+	}
 	return extclient, err
 }
 
@@ -235,7 +246,9 @@ func SaveExtClient(extclient *models.ExtClient) error {
 	if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil {
 		return err
 	}
-	storeExtClientInCache(key, *extclient)
+	if servercfg.CacheEnabled() {
+		storeExtClientInCache(key, *extclient)
+	}
 	return SetNetworkNodesLastModified(extclient.Network)
 }
 

+ 35 - 14
logic/hosts.go

@@ -81,16 +81,21 @@ const (
 
 // GetAllHosts - returns all hosts in flat list or error
 func GetAllHosts() ([]models.Host, error) {
-	currHosts := getHostsFromCache()
-	if len(currHosts) != 0 {
-		return currHosts, nil
+	var currHosts []models.Host
+	if servercfg.CacheEnabled() {
+		currHosts := getHostsFromCache()
+		if len(currHosts) != 0 {
+			return currHosts, nil
+		}
 	}
 	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 		return nil, err
 	}
 	currHostsMap := make(map[string]models.Host)
-	defer loadHostsIntoCache(currHostsMap)
+	if servercfg.CacheEnabled() {
+		defer loadHostsIntoCache(currHostsMap)
+	}
 	for k := range records {
 		var h models.Host
 		err = json.Unmarshal([]byte(records[k]), &h)
@@ -116,16 +121,20 @@ func GetAllHostsAPI(hosts []models.Host) []models.ApiHost {
 
 // GetHostsMap - gets all the current hosts on machine in a map
 func GetHostsMap() (map[string]models.Host, error) {
-	hostsMap := getHostsMapFromCache()
-	if len(hostsMap) != 0 {
-		return hostsMap, nil
+	if servercfg.CacheEnabled() {
+		hostsMap := getHostsMapFromCache()
+		if len(hostsMap) != 0 {
+			return hostsMap, nil
+		}
 	}
 	records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
 	if err != nil && !database.IsEmptyRecord(err) {
 		return nil, err
 	}
 	currHostMap := make(map[string]models.Host)
-	defer loadHostsIntoCache(currHostMap)
+	if servercfg.CacheEnabled() {
+		defer loadHostsIntoCache(currHostMap)
+	}
 	for k := range records {
 		var h models.Host
 		err = json.Unmarshal([]byte(records[k]), &h)
@@ -140,8 +149,10 @@ func GetHostsMap() (map[string]models.Host, error) {
 
 // GetHost - gets a host from db given id
 func GetHost(hostid string) (*models.Host, error) {
-	if host, ok := getHostFromCache(hostid); ok {
-		return &host, nil
+	if servercfg.CacheEnabled() {
+		if host, ok := getHostFromCache(hostid); ok {
+			return &host, nil
+		}
 	}
 	record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid)
 	if err != nil {
@@ -152,7 +163,10 @@ func GetHost(hostid string) (*models.Host, error) {
 	if err = json.Unmarshal([]byte(record), &h); err != nil {
 		return nil, err
 	}
-	storeHostInCache(h)
+	if servercfg.CacheEnabled() {
+		storeHostInCache(h)
+	}
+
 	return &h, nil
 }
 
@@ -279,7 +293,10 @@ func UpsertHost(h *models.Host) error {
 	if err != nil {
 		return err
 	}
-	storeHostInCache(*h)
+	if servercfg.CacheEnabled() {
+		storeHostInCache(*h)
+	}
+
 	return nil
 }
 
@@ -303,8 +320,10 @@ func RemoveHost(h *models.Host, forceDelete bool) error {
 	if err != nil {
 		return err
 	}
+	if servercfg.CacheEnabled() {
+		deleteHostFromCache(h.ID.String())
+	}
 
-	deleteHostFromCache(h.ID.String())
 	return nil
 }
 
@@ -318,7 +337,9 @@ func RemoveHostByID(hostID string) error {
 	if err != nil {
 		return err
 	}
-	deleteHostFromCache(hostID)
+	if servercfg.CacheEnabled() {
+		deleteHostFromCache(hostID)
+	}
 	return nil
 }
 

+ 30 - 12
logic/nodes.go

@@ -119,7 +119,9 @@ func UpdateNodeCheckin(node *models.Node) error {
 	if err != nil {
 		return err
 	}
-	storeNodeInCache(*node)
+	if servercfg.CacheEnabled() {
+		storeNodeInCache(*node)
+	}
 	return nil
 }
 
@@ -134,7 +136,9 @@ func UpsertNode(newNode *models.Node) error {
 	if err != nil {
 		return err
 	}
-	storeNodeInCache(*newNode)
+	if servercfg.CacheEnabled() {
+		storeNodeInCache(*newNode)
+	}
 	return nil
 }
 
@@ -171,7 +175,9 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
 			if err != nil {
 				return err
 			}
-			storeNodeInCache(*newNode)
+			if servercfg.CacheEnabled() {
+				storeNodeInCache(*newNode)
+			}
 			return nil
 		}
 	}
@@ -264,7 +270,9 @@ func DeleteNodeByID(node *models.Node) error {
 			return err
 		}
 	}
-	deleteNodeFromCache(node.ID.String())
+	if servercfg.CacheEnabled() {
+		deleteNodeFromCache(node.ID.String())
+	}
 	if servercfg.IsDNSMode() {
 		SetDNS()
 	}
@@ -310,12 +318,16 @@ func ValidateNode(node *models.Node, isUpdate bool) error {
 // GetAllNodes - returns all nodes in the DB
 func GetAllNodes() ([]models.Node, error) {
 	var nodes []models.Node
-	nodes = getNodesFromCache()
-	if len(nodes) != 0 {
-		return nodes, nil
+	if servercfg.CacheEnabled() {
+		nodes = getNodesFromCache()
+		if len(nodes) != 0 {
+			return nodes, nil
+		}
 	}
 	nodesMap := make(map[string]models.Node)
-	defer loadNodesIntoCache(nodesMap)
+	if servercfg.CacheEnabled() {
+		defer loadNodesIntoCache(nodesMap)
+	}
 	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
 	if err != nil {
 		if database.IsEmptyRecord(err) {
@@ -389,8 +401,10 @@ func GetRecordKey(id string, network string) (string, error) {
 }
 
 func GetNodeByID(uuid string) (models.Node, error) {
-	if node, ok := getNodeFromCache(uuid); ok {
-		return node, nil
+	if servercfg.CacheEnabled() {
+		if node, ok := getNodeFromCache(uuid); ok {
+			return node, nil
+		}
 	}
 	var record, err = database.FetchRecord(database.NODES_TABLE_NAME, uuid)
 	if err != nil {
@@ -400,7 +414,9 @@ func GetNodeByID(uuid string) (models.Node, error) {
 	if err = json.Unmarshal([]byte(record), &node); err != nil {
 		return models.Node{}, err
 	}
-	storeNodeInCache(node)
+	if servercfg.CacheEnabled() {
+		storeNodeInCache(node)
+	}
 	return node, nil
 }
 
@@ -556,7 +572,9 @@ func createNode(node *models.Node) error {
 	if err != nil {
 		return err
 	}
-	storeNodeInCache(*node)
+	if servercfg.CacheEnabled() {
+		storeNodeInCache(*node)
+	}
 	_, err = nodeacls.CreateNodeACL(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), defaultACLVal)
 	if err != nil {
 		logger.Log(1, "failed to create node ACL for node,", node.ID.String(), "err:", err.Error())

+ 2 - 0
scripts/netmaker.default.env

@@ -81,3 +81,5 @@ OIDC_ISSUER=
 JWT_VALIDITY_DURATION=43200
 # Auto disable a user's connecteds clients bassed on JWT token expiration
 RAC_AUTO_DISABLE="true"
+# if turned on data will be cached on to improve performance significantly (IMPORTANT: If HA set to `false` )
+CACHING_ENABLED="true

+ 11 - 0
servercfg/serverconf.go

@@ -207,6 +207,17 @@ func GetDB() string {
 	return database
 }
 
+// CacheEnabled - checks if cache is enabled
+func CacheEnabled() bool {
+	caching := false
+	if os.Getenv("CACHING_ENABLED") != "" {
+		caching = os.Getenv("CACHING_ENABLED") == "true"
+	} else if config.Config.Server.Database != "" {
+		caching = config.Config.Server.CacheEnabled
+	}
+	return caching
+}
+
 // GetAPIHost - gets the api host
 func GetAPIHost() string {
 	serverhost := "127.0.0.1"