Browse Source

enforce new acl policy access check

abhishek9686 11 months ago
parent
commit
fcd3325173
3 changed files with 110 additions and 23 deletions
  1. 101 21
      logic/acls.go
  2. 1 0
      logic/peers.go
  3. 8 2
      models/acl.go

+ 101 - 21
logic/acls.go

@@ -2,8 +2,8 @@ package logic
 
 import (
 	"encoding/json"
+	"errors"
 	"sort"
-	"strings"
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
@@ -38,22 +38,22 @@ func IsAclPolicyValid(acl models.Acl) bool {
 	case models.UserPolicy:
 		// src list should only contain users
 		for _, srcI := range acl.Src {
-			userTagLi := strings.Split(srcI, ":")
-			if len(userTagLi) < 2 {
+
+			if srcI.ID == "" || srcI.Value == "" {
 				break
 			}
-			if userTagLi[0] != models.UserAclID.String() &&
-				userTagLi[0] != models.UserGroupAclID.String() {
+			if srcI.ID != models.UserAclID &&
+				srcI.ID != models.UserGroupAclID {
 				break
 			}
 			// check if user group is valid
-			if userTagLi[0] == models.UserAclID.String() {
-				_, err := GetUser(userTagLi[1])
+			if srcI.ID == models.UserAclID {
+				_, err := GetUser(srcI.Value)
 				if err != nil {
 					break
 				}
-			} else if userTagLi[0] == models.UserGroupAclID.String() {
-				err := IsGroupValid(models.UserGroupID(userTagLi[1]))
+			} else if srcI.ID == models.UserGroupAclID {
+				err := IsGroupValid(models.UserGroupID(srcI.Value))
 				if err != nil {
 					break
 				}
@@ -61,19 +61,19 @@ func IsAclPolicyValid(acl models.Acl) bool {
 
 		}
 		for _, dstI := range acl.Dst {
-			dstILi := strings.Split(dstI, ":")
-			if len(dstILi) < 2 {
+
+			if dstI.ID == "" || dstI.Value == "" {
 				break
 			}
-			if dstILi[0] == models.UserAclID.String() ||
-				dstILi[0] == models.UserGroupAclID.String() {
+			if dstI.ID == models.UserAclID ||
+				dstI.ID == models.UserGroupAclID {
 				break
 			}
-			if dstILi[0] != models.DeviceAclID.String() {
+			if dstI.ID != models.DeviceAclID {
 				break
 			}
 			// check if tag is valid
-			_, err := GetTag(models.TagID(dstILi[1]))
+			_, err := GetTag(models.TagID(dstI.Value))
 			if err != nil {
 				break
 			}
@@ -81,20 +81,29 @@ func IsAclPolicyValid(acl models.Acl) bool {
 		isValid = true
 	case models.DevicePolicy:
 		for _, srcI := range acl.Src {
-			deviceTagLi := strings.Split(srcI, ":")
-			if len(deviceTagLi) < 2 {
+			if srcI.ID == "" || srcI.Value == "" {
+				break
+			}
+			if srcI.ID != models.DeviceAclID {
 				break
 			}
-			if deviceTagLi[0] != models.DeviceAclID.String() {
+			// check if tag is valid
+			_, err := GetTag(models.TagID(srcI.Value))
+			if err != nil {
 				break
 			}
 		}
 		for _, dstI := range acl.Dst {
-			deviceTagLi := strings.Split(dstI, ":")
-			if len(deviceTagLi) < 2 {
+
+			if dstI.ID == "" || dstI.Value == "" {
+				break
+			}
+			if dstI.ID != models.DeviceAclID {
 				break
 			}
-			if deviceTagLi[0] != models.DeviceAclID.String() {
+			// check if tag is valid
+			_, err := GetTag(models.TagID(dstI.Value))
+			if err != nil {
 				break
 			}
 		}
@@ -124,6 +133,36 @@ func DeleteAcl(a models.Acl) error {
 	return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID.String())
 }
 
+func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
+	acls, _ := ListAcls(netID)
+	for _, acl := range acls {
+		if acl.Default && acl.RuleType == ruleType {
+			return acl, nil
+		}
+	}
+	return models.Acl{}, errors.New("default rule not found")
+}
+
+// listDevicePolicies - lists all device policies in a network
+func listDevicePolicies(netID models.NetworkID) []models.Acl {
+	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
+	if err != nil && !database.IsEmptyRecord(err) {
+		return []models.Acl{}
+	}
+	acls := []models.Acl{}
+	for _, dataI := range data {
+		acl := models.Acl{}
+		err := json.Unmarshal([]byte(dataI), &acl)
+		if err != nil {
+			continue
+		}
+		if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
+			acls = append(acls, acl)
+		}
+	}
+	return acls
+}
+
 // ListAcls - lists all acl policies
 func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
 	data, err := database.FetchRecords(database.TAG_TABLE_NAME)
@@ -144,6 +183,47 @@ func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
 	return acls, nil
 }
 
+func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
+	aclValueMap := make(map[string]struct{})
+	for _, aclTagI := range acltags {
+		aclValueMap[aclTagI.ID.String()] = struct{}{}
+	}
+	return aclValueMap
+}
+
+func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
+	// check default policy if all allowed return true
+	defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+	if err == nil {
+		if defaultPolicy.Enabled {
+			return true
+		}
+	}
+	// list device policies
+	policies := listDevicePolicies(models.NetworkID(peer.Network))
+	for _, policy := range policies {
+		srcMap := convAclTagToValueMap(policy.Src)
+		dstMap := convAclTagToValueMap(policy.Dst)
+		for tagID := range peer.Tags {
+			if _, ok := dstMap[tagID.String()]; ok {
+				for tagID := range node.Tags {
+					if _, ok := srcMap[tagID.String()]; ok {
+						return true
+					}
+				}
+			}
+			if _, ok := srcMap[tagID.String()]; ok {
+				for tagID := range node.Tags {
+					if _, ok := dstMap[tagID.String()]; ok {
+						return true
+					}
+				}
+			}
+		}
+	}
+	return false
+}
+
 // SortTagEntrys - Sorts slice of Tag entries by their id
 func SortAclEntrys(acls []models.Acl) {
 	sort.Slice(acls, func(i, j int) bool {

+ 1 - 0
logic/peers.go

@@ -241,6 +241,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				!peer.PendingDelete &&
 				peer.Connected &&
 				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
+				IsNodeAllowedToCommunicate(node, peer) &&
 				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
 				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
 			}

+ 8 - 2
models/acl.go

@@ -23,6 +23,11 @@ const (
 	DevicePolicy AclPolicyType = "device-policy"
 )
 
+type AclPolicyTag struct {
+	ID    AclGroupType `json:"id"`
+	Value string       `json:"value"`
+}
+
 type AclGroupType string
 
 const (
@@ -39,11 +44,12 @@ func (g AclGroupType) String() string {
 
 type Acl struct {
 	ID               uuid.UUID               `json:"id"`
+	Default          bool                    `json:"default"`
 	Name             string                  `json:"name"`
 	NetworkID        NetworkID               `json:"network_id"`
 	RuleType         AclPolicyType           `json:"policy_type"`
-	Src              []string                `json:"src_type"`
-	Dst              []string                `json:"dst_type"`
+	Src              []AclPolicyTag          `json:"src_type"`
+	Dst              []AclPolicyTag          `json:"dst_type"`
 	AllowedDirection AllowedTrafficDirection `json:"allowed_traffic_direction"`
 	Enabled          bool                    `json:"enabled"`
 	CreatedBy        string                  `json:"created_by"`