Browse Source

NET-969: add additional acl mutex (#2827)

* add additional acl mutex

* fix acls race issue

* rm defer mutex
Abhishek K 1 year ago
parent
commit
91bfcba8e2
3 changed files with 37 additions and 20 deletions
  1. 17 17
      logic/acls/common.go
  2. 7 1
      logic/acls/nodeacls/modify.go
  3. 13 2
      logic/acls/nodeacls/retrieve.go

+ 17 - 17
logic/acls/common.go

@@ -12,7 +12,7 @@ import (
 var (
 	aclCacheMutex = &sync.RWMutex{}
 	aclCacheMap   = make(map[ContainerID]ACLContainer)
-	aclMutex      = &sync.RWMutex{}
+	AclMutex      = &sync.RWMutex{}
 )
 
 func fetchAclContainerFromCache(containerID ContainerID) (aclCont ACLContainer, ok bool) {
@@ -38,22 +38,22 @@ func DeleteAclFromCache(containerID ContainerID) {
 
 // ACL.Allow - allows access by ID in memory
 func (acl ACL) Allow(ID AclID) {
-	aclMutex.Lock()
-	defer aclMutex.Unlock()
+	AclMutex.Lock()
+	defer AclMutex.Unlock()
 	acl[ID] = Allowed
 }
 
 // ACL.DisallowNode - disallows access by ID in memory
 func (acl ACL) Disallow(ID AclID) {
-	aclMutex.Lock()
-	defer aclMutex.Unlock()
+	AclMutex.Lock()
+	defer AclMutex.Unlock()
 	acl[ID] = NotAllowed
 }
 
 // ACL.Remove - removes a node from a ACL in memory
 func (acl ACL) Remove(ID AclID) {
-	aclMutex.Lock()
-	defer aclMutex.Unlock()
+	AclMutex.Lock()
+	defer AclMutex.Unlock()
 	delete(acl, ID)
 }
 
@@ -64,24 +64,24 @@ func (acl ACL) Save(containerID ContainerID, ID AclID) (ACL, error) {
 
 // ACL.IsAllowed - sees if ID is allowed in referring ACL
 func (acl ACL) IsAllowed(ID AclID) (allowed bool) {
-	aclMutex.RLock()
+	AclMutex.RLock()
 	allowed = acl[ID] == Allowed
-	aclMutex.RUnlock()
+	AclMutex.RUnlock()
 	return
 }
 
 // ACLContainer.UpdateACL - saves the state of a ACL in the ACLContainer in memory
 func (aclContainer ACLContainer) UpdateACL(ID AclID, acl ACL) ACLContainer {
-	aclMutex.Lock()
-	defer aclMutex.Unlock()
+	AclMutex.Lock()
+	defer AclMutex.Unlock()
 	aclContainer[ID] = acl
 	return aclContainer
 }
 
 // ACLContainer.RemoveACL - removes the state of a ACL in the ACLContainer in memory
 func (aclContainer ACLContainer) RemoveACL(ID AclID) ACLContainer {
-	aclMutex.Lock()
-	defer aclMutex.Unlock()
+	AclMutex.Lock()
+	defer AclMutex.Unlock()
 	delete(aclContainer, ID)
 	return aclContainer
 }
@@ -127,8 +127,8 @@ func (aclContainer ACLContainer) Get(containerID ContainerID) (ACLContainer, err
 
 // fetchACLContainer - fetches all current rules in given ACL container
 func fetchACLContainer(containerID ContainerID) (ACLContainer, error) {
-	aclMutex.RLock()
-	defer aclMutex.RUnlock()
+	AclMutex.RLock()
+	defer AclMutex.RUnlock()
 	if servercfg.CacheEnabled() {
 		if aclContainer, ok := fetchAclContainerFromCache(containerID); ok {
 			return aclContainer, nil
@@ -171,8 +171,8 @@ func upsertACL(containerID ContainerID, ID AclID, acl ACL) (ACL, error) {
 // upsertACLContainer - Inserts or updates a network ACL given the json string of the ACL and the container ID
 // if nil, create it
 func upsertACLContainer(containerID ContainerID, aclContainer ACLContainer) (ACLContainer, error) {
-	aclMutex.Lock()
-	defer aclMutex.Unlock()
+	AclMutex.Lock()
+	defer AclMutex.Unlock()
 	if aclContainer == nil {
 		aclContainer = make(ACLContainer)
 	}

+ 7 - 1
logic/acls/nodeacls/modify.go

@@ -22,12 +22,14 @@ func CreateNodeACL(networkID NetworkID, nodeID NodeID, defaultVal byte) (acls.AC
 			return nil, err
 		}
 	}
+	acls.AclMutex.Lock()
 	var newNodeACL = make(acls.ACL)
 	for existingNodeID := range currentNetworkACL {
 		currentNetworkACL[existingNodeID][acls.AclID(nodeID)] = defaultVal // set the old nodes to default value for new node
 		newNodeACL[existingNodeID] = defaultVal                            // set the old nodes in new node ACL to default value
 	}
-	currentNetworkACL[acls.AclID(nodeID)] = newNodeACL                        // append the new node's ACL
+	currentNetworkACL[acls.AclID(nodeID)] = newNodeACL // append the new node's ACL
+	acls.AclMutex.Unlock()
 	retNetworkACL, err := currentNetworkACL.Save(acls.ContainerID(networkID)) // insert into db
 	if err != nil {
 		return nil, err
@@ -63,7 +65,9 @@ func UpdateNodeACL(networkID NetworkID, nodeID NodeID, acl acls.ACL) (acls.ACL,
 	if err != nil {
 		return nil, err
 	}
+	acls.AclMutex.Lock()
 	currentNetworkACL[acls.AclID(nodeID)] = acl
+	acls.AclMutex.Unlock()
 	return currentNetworkACL[acls.AclID(nodeID)].Save(acls.ContainerID(networkID), acls.AclID(nodeID))
 }
 
@@ -73,12 +77,14 @@ func RemoveNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACLContainer, error
 	if err != nil {
 		return nil, err
 	}
+	acls.AclMutex.Lock()
 	for currentNodeID := range currentNetworkACL {
 		if NodeID(currentNodeID) != nodeID {
 			currentNetworkACL[currentNodeID].Remove(acls.AclID(nodeID))
 		}
 	}
 	delete(currentNetworkACL, acls.AclID(nodeID))
+	acls.AclMutex.Unlock()
 	return currentNetworkACL.Save(acls.ContainerID(networkID))
 }
 

+ 13 - 2
logic/acls/nodeacls/retrieve.go

@@ -13,7 +13,11 @@ func AreNodesAllowed(networkID NetworkID, node1, node2 NodeID) bool {
 	if err != nil {
 		return false
 	}
-	return currentNetworkACL[acls.AclID(node1)].IsAllowed(acls.AclID(node2)) && currentNetworkACL[acls.AclID(node2)].IsAllowed(acls.AclID(node1))
+	var allowed bool
+	acls.AclMutex.RLock()
+	allowed = currentNetworkACL[acls.AclID(node1)].IsAllowed(acls.AclID(node2)) && currentNetworkACL[acls.AclID(node2)].IsAllowed(acls.AclID(node1))
+	acls.AclMutex.RUnlock()
+	return allowed
 }
 
 // FetchNodeACL - fetches a specific node's ACL in a given network
@@ -22,10 +26,15 @@ func FetchNodeACL(networkID NetworkID, nodeID NodeID) (acls.ACL, error) {
 	if err != nil {
 		return nil, err
 	}
+	var acl acls.ACL
+	acls.AclMutex.RLock()
 	if currentNetworkACL[acls.AclID(nodeID)] == nil {
+		acls.AclMutex.RUnlock()
 		return nil, fmt.Errorf("no node ACL present for node %s", nodeID)
 	}
-	return currentNetworkACL[acls.AclID(nodeID)], nil
+	acl = currentNetworkACL[acls.AclID(nodeID)]
+	acls.AclMutex.RUnlock()
+	return acl, nil
 }
 
 // FetchNodeACLJson - fetches a node's acl in given network except returns the json string
@@ -34,6 +43,8 @@ func FetchNodeACLJson(networkID NetworkID, nodeID NodeID) (acls.ACLJson, error)
 	if err != nil {
 		return "", err
 	}
+	acls.AclMutex.RLock()
+	defer acls.AclMutex.RUnlock()
 	jsonData, err := json.Marshal(&currentNodeACL)
 	if err != nil {
 		return "", err