|
@@ -2,8 +2,8 @@ package logic
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
+ "errors"
|
|
|
"sort"
|
|
|
- "strings"
|
|
|
|
|
|
"github.com/gravitl/netmaker/database"
|
|
|
"github.com/gravitl/netmaker/models"
|
|
@@ -38,22 +38,22 @@ func IsAclPolicyValid(acl models.Acl) bool {
|
|
|
case models.UserPolicy:
|
|
|
// src list should only contain users
|
|
|
for _, srcI := range acl.Src {
|
|
|
- userTagLi := strings.Split(srcI, ":")
|
|
|
- if len(userTagLi) < 2 {
|
|
|
+
|
|
|
+ if srcI.ID == "" || srcI.Value == "" {
|
|
|
break
|
|
|
}
|
|
|
- if userTagLi[0] != models.UserAclID.String() &&
|
|
|
- userTagLi[0] != models.UserGroupAclID.String() {
|
|
|
+ if srcI.ID != models.UserAclID &&
|
|
|
+ srcI.ID != models.UserGroupAclID {
|
|
|
break
|
|
|
}
|
|
|
// check if user group is valid
|
|
|
- if userTagLi[0] == models.UserAclID.String() {
|
|
|
- _, err := GetUser(userTagLi[1])
|
|
|
+ if srcI.ID == models.UserAclID {
|
|
|
+ _, err := GetUser(srcI.Value)
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
- } else if userTagLi[0] == models.UserGroupAclID.String() {
|
|
|
- err := IsGroupValid(models.UserGroupID(userTagLi[1]))
|
|
|
+ } else if srcI.ID == models.UserGroupAclID {
|
|
|
+ err := IsGroupValid(models.UserGroupID(srcI.Value))
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
@@ -61,19 +61,19 @@ func IsAclPolicyValid(acl models.Acl) bool {
|
|
|
|
|
|
}
|
|
|
for _, dstI := range acl.Dst {
|
|
|
- dstILi := strings.Split(dstI, ":")
|
|
|
- if len(dstILi) < 2 {
|
|
|
+
|
|
|
+ if dstI.ID == "" || dstI.Value == "" {
|
|
|
break
|
|
|
}
|
|
|
- if dstILi[0] == models.UserAclID.String() ||
|
|
|
- dstILi[0] == models.UserGroupAclID.String() {
|
|
|
+ if dstI.ID == models.UserAclID ||
|
|
|
+ dstI.ID == models.UserGroupAclID {
|
|
|
break
|
|
|
}
|
|
|
- if dstILi[0] != models.DeviceAclID.String() {
|
|
|
+ if dstI.ID != models.DeviceAclID {
|
|
|
break
|
|
|
}
|
|
|
// check if tag is valid
|
|
|
- _, err := GetTag(models.TagID(dstILi[1]))
|
|
|
+ _, err := GetTag(models.TagID(dstI.Value))
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
@@ -81,20 +81,29 @@ func IsAclPolicyValid(acl models.Acl) bool {
|
|
|
isValid = true
|
|
|
case models.DevicePolicy:
|
|
|
for _, srcI := range acl.Src {
|
|
|
- deviceTagLi := strings.Split(srcI, ":")
|
|
|
- if len(deviceTagLi) < 2 {
|
|
|
+ if srcI.ID == "" || srcI.Value == "" {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if srcI.ID != models.DeviceAclID {
|
|
|
break
|
|
|
}
|
|
|
- if deviceTagLi[0] != models.DeviceAclID.String() {
|
|
|
+ // check if tag is valid
|
|
|
+ _, err := GetTag(models.TagID(srcI.Value))
|
|
|
+ if err != nil {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
|
for _, dstI := range acl.Dst {
|
|
|
- deviceTagLi := strings.Split(dstI, ":")
|
|
|
- if len(deviceTagLi) < 2 {
|
|
|
+
|
|
|
+ if dstI.ID == "" || dstI.Value == "" {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if dstI.ID != models.DeviceAclID {
|
|
|
break
|
|
|
}
|
|
|
- if deviceTagLi[0] != models.DeviceAclID.String() {
|
|
|
+ // check if tag is valid
|
|
|
+ _, err := GetTag(models.TagID(dstI.Value))
|
|
|
+ if err != nil {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
@@ -124,6 +133,36 @@ func DeleteAcl(a models.Acl) error {
|
|
|
return database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID.String())
|
|
|
}
|
|
|
|
|
|
+func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
|
|
|
+ acls, _ := ListAcls(netID)
|
|
|
+ for _, acl := range acls {
|
|
|
+ if acl.Default && acl.RuleType == ruleType {
|
|
|
+ return acl, nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return models.Acl{}, errors.New("default rule not found")
|
|
|
+}
|
|
|
+
|
|
|
+// listDevicePolicies - lists all device policies in a network
|
|
|
+func listDevicePolicies(netID models.NetworkID) []models.Acl {
|
|
|
+ data, err := database.FetchRecords(database.TAG_TABLE_NAME)
|
|
|
+ if err != nil && !database.IsEmptyRecord(err) {
|
|
|
+ return []models.Acl{}
|
|
|
+ }
|
|
|
+ acls := []models.Acl{}
|
|
|
+ for _, dataI := range data {
|
|
|
+ acl := models.Acl{}
|
|
|
+ err := json.Unmarshal([]byte(dataI), &acl)
|
|
|
+ if err != nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy {
|
|
|
+ acls = append(acls, acl)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return acls
|
|
|
+}
|
|
|
+
|
|
|
// ListAcls - lists all acl policies
|
|
|
func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
|
|
|
data, err := database.FetchRecords(database.TAG_TABLE_NAME)
|
|
@@ -144,6 +183,47 @@ func ListAcls(netID models.NetworkID) ([]models.Acl, error) {
|
|
|
return acls, nil
|
|
|
}
|
|
|
|
|
|
+func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
|
|
|
+ aclValueMap := make(map[string]struct{})
|
|
|
+ for _, aclTagI := range acltags {
|
|
|
+ aclValueMap[aclTagI.ID.String()] = struct{}{}
|
|
|
+ }
|
|
|
+ return aclValueMap
|
|
|
+}
|
|
|
+
|
|
|
+func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
|
|
|
+ // check default policy if all allowed return true
|
|
|
+ defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
|
|
+ if err == nil {
|
|
|
+ if defaultPolicy.Enabled {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // list device policies
|
|
|
+ policies := listDevicePolicies(models.NetworkID(peer.Network))
|
|
|
+ for _, policy := range policies {
|
|
|
+ srcMap := convAclTagToValueMap(policy.Src)
|
|
|
+ dstMap := convAclTagToValueMap(policy.Dst)
|
|
|
+ for tagID := range peer.Tags {
|
|
|
+ if _, ok := dstMap[tagID.String()]; ok {
|
|
|
+ for tagID := range node.Tags {
|
|
|
+ if _, ok := srcMap[tagID.String()]; ok {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if _, ok := srcMap[tagID.String()]; ok {
|
|
|
+ for tagID := range node.Tags {
|
|
|
+ if _, ok := dstMap[tagID.String()]; ok {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|
|
|
+
|
|
|
// SortTagEntrys - Sorts slice of Tag entries by their id
|
|
|
func SortAclEntrys(acls []models.Acl) {
|
|
|
sort.Slice(acls, func(i, j int) bool {
|