Sfoglia il codice sorgente

fetch acl rules for egress networks

abhishek9686 6 mesi fa
parent
commit
38e5b58ac8
4 ha cambiato i file con 274 aggiunte e 8 eliminazioni
  1. 231 0
      logic/acls.go
  2. 40 6
      logic/peers.go
  3. 2 1
      models/acl.go
  4. 1 1
      models/mqtt.go

+ 231 - 0
logic/acls.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"maps"
+	"net"
 	"sort"
 	"sync"
 	"time"
@@ -1630,3 +1631,233 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu
 	}
 	return rules
 }
+
+func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclRule) {
+	rules = make(map[string]models.AclRule)
+
+	taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false)
+
+	acls := listDevicePolicies(models.NetworkID(targetnode.Network))
+	var targetNodeTags = make(map[models.TagID]struct{})
+	targetNodeTags["*"] = struct{}{}
+	/*
+		 if target node is egress gateway
+			if acl policy has egress route and it is present in target node egress ranges
+			fetches all the nodes in that policy and add rules
+	*/
+	if targetnode.IsEgressGateway {
+		for _, rangeI := range targetnode.EgressGatewayRanges {
+			targetNodeTags[models.TagID(rangeI)] = struct{}{}
+		}
+	}
+	for _, acl := range acls {
+		if !acl.Enabled {
+			continue
+		}
+		srcTags := convAclTagToValueMap(acl.Src)
+		dstTags := convAclTagToValueMap(acl.Dst)
+		_, srcAll := srcTags["*"]
+		_, dstAll := dstTags["*"]
+
+		for nodeTag := range targetNodeTags {
+			aclRule := models.AclRule{
+				ID:              acl.ID,
+				AllowedProtocol: acl.Proto,
+				AllowedPorts:    acl.Port,
+				Direction:       acl.AllowedDirection,
+				Allowed:         true,
+			}
+			if nodeTag != "*" {
+				ip, cidr, err := net.ParseCIDR(nodeTag.String())
+				if err != nil {
+					continue
+				}
+				if ip.To4() != nil {
+					aclRule.Dst = *cidr
+				} else {
+					aclRule.Dst6 = *cidr
+				}
+
+			} else {
+				aclRule.Dst = net.IPNet{
+					IP:   net.IPv4zero,        // 0.0.0.0
+					Mask: net.CIDRMask(0, 32), // /0 means match all IPv4
+				}
+				aclRule.Dst6 = net.IPNet{
+					IP:   net.IPv6zero,         // ::
+					Mask: net.CIDRMask(0, 128), // /0 means match all IPv6
+				}
+			}
+			if acl.AllowedDirection == models.TrafficDirectionBi {
+				var existsInSrcTag bool
+				var existsInDstTag bool
+
+				if _, ok := srcTags[nodeTag.String()]; ok || srcAll {
+					existsInSrcTag = true
+				}
+				if _, ok := srcTags[targetnode.ID.String()]; ok || srcAll {
+					existsInSrcTag = true
+				}
+				if _, ok := dstTags[nodeTag.String()]; ok || dstAll {
+					existsInDstTag = true
+				}
+				if _, ok := dstTags[targetnode.ID.String()]; ok || dstAll {
+					existsInDstTag = true
+				}
+
+				if existsInSrcTag && !existsInDstTag {
+					// get all dst tags
+					for dst := range dstTags {
+						if dst == nodeTag.String() {
+							continue
+						}
+						// Get peers in the tags and add allowed rules
+						nodes := taggedNodes[models.TagID(dst)]
+						if dst != targetnode.ID.String() {
+							node, err := GetNodeByID(dst)
+							if err == nil {
+								nodes = append(nodes, node)
+							}
+						}
+
+						for _, node := range nodes {
+							if node.ID == targetnode.ID {
+								continue
+							}
+							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+								continue
+							}
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
+							if node.IsStatic && node.StaticNode.Address != "" {
+								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+							}
+							if node.IsStatic && node.StaticNode.Address6 != "" {
+								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+							}
+						}
+					}
+				}
+				if existsInDstTag && !existsInSrcTag {
+					// get all src tags
+					for src := range srcTags {
+						if src == nodeTag.String() {
+							continue
+						}
+						// Get peers in the tags and add allowed rules
+						nodes := taggedNodes[models.TagID(src)]
+						if src != targetnode.ID.String() {
+							node, err := GetNodeByID(src)
+							if err == nil {
+								nodes = append(nodes, node)
+							}
+						}
+						for _, node := range nodes {
+							if node.ID == targetnode.ID {
+								continue
+							}
+							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+								continue
+							}
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
+							if node.IsStatic && node.StaticNode.Address != "" {
+								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+							}
+							if node.IsStatic && node.StaticNode.Address6 != "" {
+								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+							}
+						}
+					}
+				}
+				if existsInDstTag && existsInSrcTag {
+					nodes := taggedNodes[nodeTag]
+					for srcID := range srcTags {
+						if srcID == targetnode.ID.String() {
+							continue
+						}
+						node, err := GetNodeByID(srcID)
+						if err == nil {
+							nodes = append(nodes, node)
+						}
+					}
+					for dstID := range dstTags {
+						if dstID == targetnode.ID.String() {
+							continue
+						}
+						node, err := GetNodeByID(dstID)
+						if err == nil {
+							nodes = append(nodes, node)
+						}
+					}
+					for _, node := range nodes {
+						if node.ID == targetnode.ID {
+							continue
+						}
+						if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+							continue
+						}
+						if node.Address.IP != nil {
+							aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+						}
+						if node.Address6.IP != nil {
+							aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+						}
+						if node.IsStatic && node.StaticNode.Address != "" {
+							aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+						}
+						if node.IsStatic && node.StaticNode.Address6 != "" {
+							aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+						}
+					}
+				}
+			} else {
+				_, all := dstTags["*"]
+				if _, ok := dstTags[nodeTag.String()]; ok || all {
+					// get all src tags
+					for src := range srcTags {
+						if src == nodeTag.String() {
+							continue
+						}
+						// Get peers in the tags and add allowed rules
+						nodes := taggedNodes[models.TagID(src)]
+						for _, node := range nodes {
+							if node.ID == targetnode.ID {
+								continue
+							}
+							if node.IsStatic && node.StaticNode.IngressGatewayID == targetnode.ID.String() {
+								continue
+							}
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
+							if node.IsStatic && node.StaticNode.Address != "" {
+								aclRule.IPList = append(aclRule.IPList, node.StaticNode.AddressIPNet4())
+							}
+							if node.IsStatic && node.StaticNode.Address6 != "" {
+								aclRule.IP6List = append(aclRule.IP6List, node.StaticNode.AddressIPNet6())
+							}
+						}
+					}
+				}
+			}
+			if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
+				rules[acl.ID] = aclRule
+			}
+
+		}
+
+	}
+	return
+}

+ 40 - 6
logic/peers.go

@@ -181,6 +181,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 	slog.Debug("peer update for host", "hostId", host.ID.String())
 	peerIndexMap := make(map[string]int)
 	for _, nodeID := range host.Nodes {
+		networkAllowAll := true
 		nodeID := nodeID
 		node, err := GetNodeByID(nodeID)
 		if err != nil {
@@ -258,12 +259,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 			if node.NetworkRange6.IP != nil {
 				hostPeerUpdate.FwUpdate.AllowedNetworks = append(hostPeerUpdate.FwUpdate.AllowedNetworks, node.NetworkRange6)
 			}
-			if node.IsEgressGateway {
-				// get egress ranges
-				hostPeerUpdate.FwUpdate.EgressNetworks = append(hostPeerUpdate.FwUpdate.EgressNetworks, node.EgressGatewayRanges...)
-
-			}
 		} else {
+			networkAllowAll = false
 			hostPeerUpdate.FwUpdate.AllowAll = false
 			rules := GetAclRulesForNode(&node)
 			if len(hostPeerUpdate.FwUpdate.AclRules) == 0 {
@@ -478,10 +475,47 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 					Mask: getCIDRMaskFromAddr(node.Address6.IP.String()),
 				},
 				EgressGWCfg:   node.EgressGatewayRequest,
-				EgressFwRules: make(map[string]models.FwRule),
+				EgressFwRules: make(map[string]models.AclRule),
+			}
+
+		}
+		if node.IsEgressGateway {
+			if networkAllowAll {
+				// get egress ranges
+				for _, egressRangeI := range node.EgressGatewayRanges {
+					ip, cidr, err := net.ParseCIDR(egressRangeI)
+					if err != nil {
+						continue
+					}
+					egressInfo := hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()]
+					if egressInfo.EgressFwRules == nil {
+						egressInfo.EgressFwRules = make(map[string]models.AclRule)
+					}
+					if ip.To4() != nil && node.NetworkRange.IP != nil {
+
+						egressInfo.EgressFwRules[egressRangeI] = models.AclRule{
+							IPList:  []net.IPNet{node.NetworkRange},
+							Dst:     *cidr,
+							Allowed: true,
+						}
+					} else if ip.To16() != nil {
+						egressInfo.EgressFwRules[egressRangeI] = models.AclRule{
+							IP6List: []net.IPNet{node.NetworkRange6},
+							Dst6:    *cidr,
+							Allowed: true,
+						}
+					}
+				}
+			} else {
+				egressInfo := hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()]
+				if egressInfo.EgressFwRules == nil {
+					egressInfo.EgressFwRules = make(map[string]models.AclRule)
+				}
+				egressInfo.EgressFwRules = GetEgressRulesForNode(node)
 			}
 
 		}
+
 		if IsInternetGw(node) {
 			hostPeerUpdate.FwUpdate.IsEgressGw = true
 			egressrange := []string{"0.0.0.0/0"}

+ 2 - 1
models/acl.go

@@ -117,6 +117,7 @@ type AclRule struct {
 	AllowedProtocol Protocol                `json:"allowed_protocols"` // tcp, udp, etc.
 	AllowedPorts    []string                `json:"allowed_ports"`
 	Direction       AllowedTrafficDirection `json:"direction"` // single or two-way
-	EgressRanges    []net.IPNet             `json:"egress_ranges"`
+	Dst             net.IPNet               `json:"dst"`
+	Dst6            net.IPNet               `json:"dst6"`
 	Allowed         bool
 }

+ 1 - 1
models/mqtt.go

@@ -66,7 +66,7 @@ type EgressInfo struct {
 	Network6      net.IPNet            `json:"network6" yaml:"network6"`
 	EgressGwAddr6 net.IPNet            `json:"egress_gw_addr6" yaml:"egress_gw_addr6"`
 	EgressGWCfg   EgressGatewayRequest `json:"egress_gateway_cfg" yaml:"egress_gateway_cfg"`
-	EgressFwRules map[string]FwRule    `json:"egress_fw_rules"`
+	EgressFwRules map[string]AclRule   `json:"egress_fw_rules"`
 }
 
 // EgressNetworkRoutes - struct for egress network routes for adding routes to peer's interface