Преглед на файлове

Merge pull request #3340 from gravitl/NET-1964

NET-1964: add node mutex to model
Abhishek K преди 6 месеца
родител
ревизия
2f0d289813
променени са 5 файла, в които са добавени 118 реда и са изтрити 34 реда
  1. 52 12
      logic/acls.go
  2. 9 0
      logic/extpeers.go
  3. 51 22
      logic/nodes.go
  4. 4 0
      models/extclient.go
  5. 2 0
      models/node.go

+ 52 - 12
logic/acls.go

@@ -17,7 +17,6 @@ import (
 var (
 	aclCacheMutex = &sync.RWMutex{}
 	aclCacheMap   = make(map[string]models.Acl)
-	aclTagsMutex  = &sync.RWMutex{}
 )
 
 func MigrateAclPolicies() {
@@ -576,10 +575,22 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 	if peer.IsStatic {
 		peer = peer.StaticNode.ConvertToStaticNode()
 	}
-	aclTagsMutex.RLock()
-	peerTags := maps.Clone(peer.Tags)
-	nodeTags := maps.Clone(node.Tags)
-	aclTagsMutex.RUnlock()
+	var nodeTags, peerTags map[models.TagID]struct{}
+	if node.Mutex != nil {
+		node.Mutex.Lock()
+		nodeTags = maps.Clone(node.Tags)
+		node.Mutex.Unlock()
+	} else {
+		nodeTags = node.Tags
+	}
+	if peer.Mutex != nil {
+		peer.Mutex.Lock()
+		peerTags = maps.Clone(peer.Tags)
+		peer.Mutex.Unlock()
+	} else {
+		peerTags = peer.Tags
+	}
+
 	if checkDefaultPolicy {
 		// check default policy if all allowed return true
 		defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@@ -661,10 +672,21 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 	if peer.IsStatic {
 		peer = peer.StaticNode.ConvertToStaticNode()
 	}
-	aclTagsMutex.RLock()
-	peerTags := maps.Clone(peer.Tags)
-	nodeTags := maps.Clone(node.Tags)
-	aclTagsMutex.RUnlock()
+	var nodeTags, peerTags map[models.TagID]struct{}
+	if node.Mutex != nil {
+		node.Mutex.Lock()
+		nodeTags = maps.Clone(node.Tags)
+		node.Mutex.Unlock()
+	} else {
+		nodeTags = node.Tags
+	}
+	if peer.Mutex != nil {
+		peer.Mutex.Lock()
+		peerTags = maps.Clone(peer.Tags)
+		peer.Mutex.Unlock()
+	} else {
+		peerTags = peer.Tags
+	}
 	if checkDefaultPolicy {
 		// check default policy if all allowed return true
 		defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@@ -862,7 +884,15 @@ func getUserAclRulesForNode(targetnode *models.Node,
 	userGrpMap := GetUserGrpMap()
 	allowedUsers := make(map[string][]models.Acl)
 	acls := listUserPolicies(models.NetworkID(targetnode.Network))
-	for nodeTag := range targetnode.Tags {
+	var targetNodeTags = make(map[models.TagID]struct{})
+	if targetnode.Mutex != nil {
+		targetnode.Mutex.Lock()
+		targetNodeTags = maps.Clone(targetnode.Tags)
+		targetnode.Mutex.Unlock()
+	} else {
+		targetNodeTags = maps.Clone(targetnode.Tags)
+	}
+	for nodeTag := range targetNodeTags {
 		for _, acl := range acls {
 			if !acl.Enabled {
 				continue
@@ -886,6 +916,7 @@ func getUserAclRulesForNode(targetnode *models.Node,
 			}
 		}
 	}
+
 	for _, userNode := range userNodes {
 		if !userNode.StaticNode.Enabled {
 			continue
@@ -942,8 +973,17 @@ func GetAclRulesForNode(targetnode *models.Node) (rules map[string]models.AclRul
 	}
 
 	acls := listDevicePolicies(models.NetworkID(targetnode.Network))
-	targetnode.Tags["*"] = struct{}{}
-	for nodeTag := range targetnode.Tags {
+
+	var targetNodeTags = make(map[models.TagID]struct{})
+	if targetnode.Mutex != nil {
+		targetnode.Mutex.Lock()
+		targetNodeTags = maps.Clone(targetnode.Tags)
+		targetnode.Mutex.Unlock()
+	} else {
+		targetNodeTags = maps.Clone(targetnode.Tags)
+	}
+	targetNodeTags["*"] = struct{}{}
+	for nodeTag := range targetNodeTags {
 		for _, acl := range acls {
 			if !acl.Enabled {
 				continue

+ 9 - 0
logic/extpeers.go

@@ -28,6 +28,9 @@ var (
 func getAllExtClientsFromCache() (extClients []models.ExtClient) {
 	extClientCacheMutex.RLock()
 	for _, extclient := range extClientCacheMap {
+		if extclient.Mutex == nil {
+			extclient.Mutex = &sync.Mutex{}
+		}
 		extClients = append(extClients, extclient)
 	}
 	extClientCacheMutex.RUnlock()
@@ -43,12 +46,18 @@ func deleteExtClientFromCache(key string) {
 func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) {
 	extClientCacheMutex.RLock()
 	extclient, ok = extClientCacheMap[key]
+	if extclient.Mutex == nil {
+		extclient.Mutex = &sync.Mutex{}
+	}
 	extClientCacheMutex.RUnlock()
 	return
 }
 
 func storeExtClientInCache(key string, extclient models.ExtClient) {
 	extClientCacheMutex.Lock()
+	if extclient.Mutex == nil {
+		extclient.Mutex = &sync.Mutex{}
+	}
 	extClientCacheMap[key] = extclient
 	extClientCacheMutex.Unlock()
 }

+ 51 - 22
logic/nodes.go

@@ -35,12 +35,18 @@ var (
 func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
 	nodeCacheMutex.RLock()
 	node, ok = nodesCacheMap[nodeID]
+	if node.Mutex == nil {
+		node.Mutex = &sync.Mutex{}
+	}
 	nodeCacheMutex.RUnlock()
 	return
 }
 func getNodesFromCache() (nodes []models.Node) {
 	nodeCacheMutex.RLock()
 	for _, node := range nodesCacheMap {
+		if node.Mutex == nil {
+			node.Mutex = &sync.Mutex{}
+		}
 		nodes = append(nodes, node)
 	}
 	nodeCacheMutex.RUnlock()
@@ -425,6 +431,9 @@ func GetAllNodes() ([]models.Node, error) {
 		}
 		// add node to our array
 		nodes = append(nodes, node)
+		if node.Mutex == nil {
+			node.Mutex = &sync.Mutex{}
+		}
 		nodesMap[node.ID.String()] = node
 	}
 
@@ -811,9 +820,16 @@ func GetTagMapWithNodes() (tagNodesMap map[models.TagID][]models.Node) {
 		if nodeI.Tags == nil {
 			continue
 		}
+		if nodeI.Mutex != nil {
+			nodeI.Mutex.Lock()
+		}
 		for nodeTagID := range nodeI.Tags {
 			tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
 		}
+		if nodeI.Mutex != nil {
+			nodeI.Mutex.Unlock()
+		}
+
 	}
 	return
 }
@@ -825,9 +841,15 @@ func GetTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (
 		if nodeI.Tags == nil {
 			continue
 		}
+		if nodeI.Mutex != nil {
+			nodeI.Mutex.Lock()
+		}
 		for nodeTagID := range nodeI.Tags {
 			tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI)
 		}
+		if nodeI.Mutex != nil {
+			nodeI.Mutex.Unlock()
+		}
 	}
 	tagNodesMap["*"] = nodes
 	if !withStaticNodes {
@@ -846,17 +868,16 @@ func AddTagMapWithStaticNodes(netID models.NetworkID,
 		if extclient.Tags == nil || extclient.RemoteAccessClientID != "" {
 			continue
 		}
+		if extclient.Mutex != nil {
+			extclient.Mutex.Lock()
+		}
 		for tagID := range extclient.Tags {
-			tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{
-				IsStatic:   true,
-				StaticNode: extclient,
-			})
-			tagNodesMap["*"] = append(tagNodesMap["*"], models.Node{
-				IsStatic:   true,
-				StaticNode: extclient,
-			})
+			tagNodesMap[tagID] = append(tagNodesMap[tagID], extclient.ConvertToStaticNode())
+			tagNodesMap["*"] = append(tagNodesMap["*"], extclient.ConvertToStaticNode())
+		}
+		if extclient.Mutex != nil {
+			extclient.Mutex.Unlock()
 		}
-
 	}
 	return tagNodesMap
 }
@@ -871,11 +892,14 @@ func AddTagMapWithStaticNodesWithUsers(netID models.NetworkID,
 		if extclient.Tags == nil {
 			continue
 		}
+		if extclient.Mutex != nil {
+			extclient.Mutex.Lock()
+		}
 		for tagID := range extclient.Tags {
-			tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{
-				IsStatic:   true,
-				StaticNode: extclient,
-			})
+			tagNodesMap[tagID] = append(tagNodesMap[tagID], extclient.ConvertToStaticNode())
+		}
+		if extclient.Mutex != nil {
+			extclient.Mutex.Unlock()
 		}
 
 	}
@@ -893,9 +917,15 @@ func GetNodesWithTag(tagID models.TagID) map[string]models.Node {
 		if nodeI.Tags == nil {
 			continue
 		}
+		if nodeI.Mutex != nil {
+			nodeI.Mutex.Lock()
+		}
 		if _, ok := nodeI.Tags[tagID]; ok {
 			nMap[nodeI.ID.String()] = nodeI
 		}
+		if nodeI.Mutex != nil {
+			nodeI.Mutex.Unlock()
+		}
 	}
 	return AddStaticNodesWithTag(tag, nMap)
 }
@@ -909,13 +939,15 @@ func AddStaticNodesWithTag(tag models.Tag, nMap map[string]models.Node) map[stri
 		if extclient.RemoteAccessClientID != "" {
 			continue
 		}
+		if extclient.Mutex != nil {
+			extclient.Mutex.Lock()
+		}
 		if _, ok := extclient.Tags[tag.ID]; ok {
-			nMap[extclient.ClientID] = models.Node{
-				IsStatic:   true,
-				StaticNode: extclient,
-			}
+			nMap[extclient.ClientID] = extclient.ConvertToStaticNode()
+		}
+		if extclient.Mutex != nil {
+			extclient.Mutex.Unlock()
 		}
-
 	}
 	return nMap
 }
@@ -931,10 +963,7 @@ func GetStaticNodeWithTag(tagID models.TagID) map[string]models.Node {
 		return nMap
 	}
 	for _, extclient := range extclients {
-		nMap[extclient.ClientID] = models.Node{
-			IsStatic:   true,
-			StaticNode: extclient,
-		}
+		nMap[extclient.ClientID] = extclient.ConvertToStaticNode()
 	}
 	return nMap
 }

+ 4 - 0
models/extclient.go

@@ -1,5 +1,7 @@
 package models
 
+import "sync"
+
 // ExtClient - struct for external clients
 type ExtClient struct {
 	ClientID               string              `json:"clientid" bson:"clientid"`
@@ -25,6 +27,7 @@ type ExtClient struct {
 	DeviceName             string              `json:"device_name"`
 	PublicEndpoint         string              `json:"public_endpoint"`
 	Country                string              `json:"country"`
+	Mutex                  *sync.Mutex         `json:"-"`
 }
 
 // CustomExtClient - struct for CustomExtClient params
@@ -55,5 +58,6 @@ func (ext *ExtClient) ConvertToStaticNode() Node {
 		Tags:       ext.Tags,
 		IsStatic:   true,
 		StaticNode: *ext,
+		Mutex:      ext.Mutex,
 	}
 }

+ 2 - 0
models/node.go

@@ -5,6 +5,7 @@ import (
 	"math/rand"
 	"net"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
@@ -117,6 +118,7 @@ type Node struct {
 	IsUserNode        bool                `json:"is_user_node"`
 	StaticNode        ExtClient           `json:"static_node"`
 	Status            NodeStatus          `json:"node_status"`
+	Mutex             *sync.Mutex         `json:"-"`
 }
 
 // LegacyNode - legacy struct for node model