Browse Source

optimise static node rules, fix traffic flows for static nodes

abhishek9686 6 months ago
parent
commit
b45a926649
5 changed files with 306 additions and 134 deletions
  1. 37 6
      controllers/acls.go
  2. 157 0
      logic/acls.go
  3. 106 124
      logic/extpeers.go
  4. 3 1
      models/extclient.go
  5. 3 3
      pro/logic/status.go

+ 37 - 6
controllers/acls.go

@@ -143,18 +143,49 @@ func aclPolicyTypes(w http.ResponseWriter, r *http.Request) {
 func aclDebug(w http.ResponseWriter, r *http.Request) {
 	nodeID, _ := url.QueryUnescape(r.URL.Query().Get("node"))
 	peerID, _ := url.QueryUnescape(r.URL.Query().Get("peer"))
+	peerIsStatic, _ := url.QueryUnescape(r.URL.Query().Get("peer_is_static"))
 	node, err := logic.GetNodeByID(nodeID)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	peer, err := logic.GetNodeByID(peerID)
-	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
-		return
+	var peer models.Node
+	if peerIsStatic == "true" {
+		extclient, err := logic.GetExtClient(peerID, node.Network)
+		if err != nil {
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
+		peer = extclient.ConvertToStaticNode()
+
+	} else {
+		peer, err = logic.GetNodeByID(peerID)
+		if err != nil {
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
+	}
+	type resp struct {
+		IsNodeAllowed bool
+		IsPeerAllowed bool
+		Policies      []models.Acl
+		IngressRules  []models.FwRule
+	}
+
+	allowed, ps := logic.IsNodeAllowedToCommunicateV1(node, peer, true)
+	isallowed := logic.IsPeerAllowed(node, peer, true)
+	re := resp{
+		IsNodeAllowed: allowed,
+		IsPeerAllowed: isallowed,
+		Policies:      ps,
+	}
+	if peerIsStatic == "true" {
+		ingress, err := logic.GetNodeByID(peer.StaticNode.IngressGatewayID)
+		if err == nil {
+			re.IngressRules = logic.GetFwRulesOnIngressGateway(ingress)
+		}
 	}
-	allowed, _ := logic.IsNodeAllowedToCommunicate(node, peer, true)
-	logic.ReturnSuccessResponseWithJson(w, r, allowed, "fetched all acls in the network ")
+	logic.ReturnSuccessResponseWithJson(w, r, re, "fetched all acls in the network ")
 }
 
 // @Summary     List Acls in a network

+ 157 - 0
logic/acls.go

@@ -822,6 +822,163 @@ func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.N
 	}
 	return false
 }
+func uniquePolicies(items []models.Acl) []models.Acl {
+	if len(items) == 0 {
+		return items
+	}
+	seen := make(map[string]bool)
+	var result []models.Acl
+	for _, item := range items {
+		if !seen[item.ID] {
+			seen[item.ID] = true
+			result = append(result, item)
+		}
+	}
+
+	return result
+}
+
+// IsNodeAllowedToCommunicate - check node is allowed to communicate with the peer // ADD ALLOWED DIRECTION - 0 => node -> peer, 1 => peer-> node,
+func IsNodeAllowedToCommunicateV1(node, peer models.Node, checkDefaultPolicy bool) (bool, []models.Acl) {
+	var nodeId, peerId string
+	if node.IsStatic {
+		nodeId = node.StaticNode.ClientID
+		node = node.StaticNode.ConvertToStaticNode()
+	} else {
+		nodeId = node.ID.String()
+	}
+	if peer.IsStatic {
+		peerId = peer.StaticNode.ClientID
+		peer = peer.StaticNode.ConvertToStaticNode()
+	} else {
+		peerId = peer.ID.String()
+	}
+
+	aclTagsMutex.RLock()
+	peerTags := maps.Clone(peer.Tags)
+	nodeTags := maps.Clone(node.Tags)
+	aclTagsMutex.RUnlock()
+	nodeTags[models.TagID(nodeId)] = struct{}{}
+	peerTags[models.TagID(peerId)] = struct{}{}
+	if checkDefaultPolicy {
+		// check default policy if all allowed return true
+		defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+		if err == nil {
+			if defaultPolicy.Enabled {
+				return true, []models.Acl{defaultPolicy}
+			}
+		}
+	}
+	allowedPolicies := []models.Acl{}
+	defer func() {
+		allowedPolicies = uniquePolicies(allowedPolicies)
+	}()
+	// list device policies
+	policies := listDevicePolicies(models.NetworkID(peer.Network))
+	srcMap := make(map[string]struct{})
+	dstMap := make(map[string]struct{})
+	defer func() {
+		srcMap = nil
+		dstMap = nil
+	}()
+	for _, policy := range policies {
+		if !policy.Enabled {
+			continue
+		}
+		allowed := false
+		srcMap = convAclTagToValueMap(policy.Src)
+		dstMap = convAclTagToValueMap(policy.Dst)
+		_, srcAll := srcMap["*"]
+		_, dstAll := dstMap["*"]
+		if policy.AllowedDirection == models.TrafficDirectionBi {
+			if _, ok := srcMap[nodeId]; ok || srcAll {
+				if _, ok := dstMap[peerId]; ok || dstAll {
+					allowedPolicies = append(allowedPolicies, policy)
+					continue
+				}
+
+			}
+			if _, ok := dstMap[nodeId]; ok || dstAll {
+				if _, ok := srcMap[peerId]; ok || srcAll {
+					allowedPolicies = append(allowedPolicies, policy)
+					continue
+				}
+			}
+		}
+		if _, ok := dstMap[peerId]; ok || dstAll {
+			if _, ok := srcMap[nodeId]; ok || srcAll {
+				allowedPolicies = append(allowedPolicies, policy)
+				continue
+			}
+		}
+		if policy.AllowedDirection == models.TrafficDirectionBi {
+
+			for tagID := range nodeTags {
+
+				if _, ok := dstMap[tagID.String()]; ok {
+					if srcAll {
+						allowed = true
+						break
+					}
+					for tagID := range peerTags {
+						if _, ok := srcMap[tagID.String()]; ok {
+							allowed = true
+							break
+						}
+					}
+				}
+				if allowed {
+					allowedPolicies = append(allowedPolicies, policy)
+					break
+				}
+				if _, ok := srcMap[tagID.String()]; ok {
+					if dstAll {
+						allowed = true
+						break
+					}
+					for tagID := range peerTags {
+						if _, ok := dstMap[tagID.String()]; ok {
+							allowed = true
+							break
+						}
+					}
+				}
+				if allowed {
+					break
+				}
+			}
+			if allowed {
+				allowedPolicies = append(allowedPolicies, policy)
+				continue
+			}
+		}
+		for tagID := range peerTags {
+			if _, ok := dstMap[tagID.String()]; ok {
+				if srcAll {
+					allowed = true
+					break
+				}
+				for tagID := range nodeTags {
+					if _, ok := srcMap[tagID.String()]; ok {
+						allowed = true
+						break
+					}
+				}
+			}
+			if allowed {
+				break
+			}
+		}
+		if allowed {
+			allowedPolicies = append(allowedPolicies, policy)
+		}
+	}
+
+	if len(allowedPolicies) > 0 {
+		return true, allowedPolicies
+	}
+	return false, allowedPolicies
+}
 
 // IsNodeAllowedToCommunicate - check node is allowed to communicate with the peer
 func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) (bool, []models.Acl) {

+ 106 - 124
logic/extpeers.go

@@ -455,9 +455,102 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
 	return
 }
 
+func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []models.Acl) (rules []models.FwRule) {
+
+	for _, policy := range allowedPolicies {
+		rules = append(rules, models.FwRule{
+			SrcIP: net.IPNet{
+				IP:   node.Address.IP,
+				Mask: net.CIDRMask(32, 32),
+			},
+			DstIP: net.IPNet{
+				IP:   peer.Address.IP,
+				Mask: net.CIDRMask(32, 32),
+			},
+			AllowedProtocol: policy.Proto,
+			AllowedPorts:    policy.Port,
+			Allow:           true,
+		})
+		if policy.AllowedDirection == models.TrafficDirectionBi {
+			rules = append(rules, models.FwRule{
+				SrcIP: net.IPNet{
+					IP:   peer.Address.IP,
+					Mask: net.CIDRMask(32, 32),
+				},
+				DstIP: net.IPNet{
+					IP:   node.Address.IP,
+					Mask: net.CIDRMask(32, 32),
+				},
+				AllowedProtocol: policy.Proto,
+				AllowedPorts:    policy.Port,
+				Allow:           true,
+			})
+		}
+		if len(node.StaticNode.ExtraAllowedIPs) > 0 {
+			for _, additionalAllowedIPNet := range node.StaticNode.ExtraAllowedIPs {
+				_, ipNet, err := net.ParseCIDR(additionalAllowedIPNet)
+				if err != nil {
+					continue
+				}
+				if ipNet.IP.To4() != nil {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   peer.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				} else {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   peer.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				}
+
+			}
+
+		}
+		if len(peer.StaticNode.ExtraAllowedIPs) > 0 {
+			for _, additionalAllowedIPNet := range peer.StaticNode.ExtraAllowedIPs {
+				_, ipNet, err := net.ParseCIDR(additionalAllowedIPNet)
+				if err != nil {
+					continue
+				}
+				if ipNet.IP.To4() != nil {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   node.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				} else {
+					rules = append(rules, models.FwRule{
+						SrcIP: net.IPNet{
+							IP:   node.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						},
+						DstIP: *ipNet,
+						Allow: true,
+					})
+				}
+
+			}
+
+		}
+	}
+
+	return
+}
+
 func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 	// fetch user access to static clients via policies
-
 	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 	nodes, _ := GetNetworkNodes(node.Network)
@@ -584,131 +677,18 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 		if !nodeI.IsStatic || nodeI.IsUserNode {
 			continue
 		}
+		if nodeI.StaticNode.IngressGatewayID != node.ID.String() {
+			continue
+		}
 		for _, peer := range nodes {
 			if peer.StaticNode.ClientID == nodeI.StaticNode.ClientID || peer.IsUserNode {
 				continue
 			}
-			if ok, allowedPolicies := IsNodeAllowedToCommunicate(nodeI, peer, true); ok {
-				if peer.IsStatic {
-					if nodeI.StaticNode.Address != "" {
-						for _, policy := range allowedPolicies {
-							rules = append(rules, models.FwRule{
-								SrcIP:           nodeI.StaticNode.AddressIPNet4(),
-								DstIP:           peer.StaticNode.AddressIPNet4(),
-								AllowedProtocol: policy.Proto,
-								AllowedPorts:    policy.Port,
-								Allow:           true,
-							})
-							if policy.AllowedDirection == models.TrafficDirectionBi {
-								rules = append(rules, models.FwRule{
-									SrcIP:           peer.StaticNode.AddressIPNet4(),
-									DstIP:           nodeI.StaticNode.AddressIPNet4(),
-									AllowedProtocol: policy.Proto,
-									AllowedPorts:    policy.Port,
-									Allow:           true,
-								})
-							}
-						}
-
-					}
-					if nodeI.StaticNode.Address6 != "" {
-						for _, policy := range allowedPolicies {
-							rules = append(rules, models.FwRule{
-								SrcIP:           nodeI.StaticNode.AddressIPNet6(),
-								DstIP:           peer.StaticNode.AddressIPNet6(),
-								AllowedProtocol: policy.Proto,
-								AllowedPorts:    policy.Port,
-								Allow:           true,
-							})
-							if policy.AllowedDirection == models.TrafficDirectionBi {
-								rules = append(rules, models.FwRule{
-									SrcIP:           peer.StaticNode.AddressIPNet6(),
-									DstIP:           nodeI.StaticNode.AddressIPNet6(),
-									AllowedProtocol: policy.Proto,
-									AllowedPorts:    policy.Port,
-									Allow:           true,
-								})
-							}
-						}
-					}
-					if len(peer.StaticNode.ExtraAllowedIPs) > 0 {
-						for _, additionalAllowedIPNet := range peer.StaticNode.ExtraAllowedIPs {
-							_, ipNet, err := net.ParseCIDR(additionalAllowedIPNet)
-							if err != nil {
-								continue
-							}
-							if ipNet.IP.To4() != nil {
-								rules = append(rules, models.FwRule{
-									SrcIP: nodeI.StaticNode.AddressIPNet4(),
-									DstIP: *ipNet,
-									Allow: true,
-								})
-							} else {
-								rules = append(rules, models.FwRule{
-									SrcIP: nodeI.StaticNode.AddressIPNet6(),
-									DstIP: *ipNet,
-									Allow: true,
-								})
-							}
-
-						}
-
-					}
-				} else {
-					if nodeI.StaticNode.Address != "" {
-						for _, policy := range allowedPolicies {
-							rules = append(rules, models.FwRule{
-								SrcIP: nodeI.StaticNode.AddressIPNet4(),
-								DstIP: net.IPNet{
-									IP:   peer.Address.IP,
-									Mask: net.CIDRMask(32, 32),
-								},
-								AllowedProtocol: policy.Proto,
-								AllowedPorts:    policy.Port,
-								Allow:           true,
-							})
-							if policy.AllowedDirection == models.TrafficDirectionBi {
-								rules = append(rules, models.FwRule{
-									SrcIP: net.IPNet{
-										IP:   peer.Address.IP,
-										Mask: net.CIDRMask(32, 32),
-									},
-									DstIP:           nodeI.StaticNode.AddressIPNet4(),
-									AllowedProtocol: policy.Proto,
-									AllowedPorts:    policy.Port,
-									Allow:           true,
-								})
-							}
-						}
-					}
-					if nodeI.StaticNode.Address6 != "" {
-						for _, policy := range allowedPolicies {
-							rules = append(rules, models.FwRule{
-								SrcIP: nodeI.StaticNode.AddressIPNet6(),
-								DstIP: net.IPNet{
-									IP:   peer.Address6.IP,
-									Mask: net.CIDRMask(128, 128),
-								},
-								AllowedProtocol: policy.Proto,
-								AllowedPorts:    policy.Port,
-								Allow:           true,
-							})
-							if policy.AllowedDirection == models.TrafficDirectionBi {
-								rules = append(rules, models.FwRule{
-									SrcIP: net.IPNet{
-										IP:   peer.Address6.IP,
-										Mask: net.CIDRMask(128, 128),
-									},
-									DstIP:           nodeI.StaticNode.AddressIPNet6(),
-									AllowedProtocol: policy.Proto,
-									AllowedPorts:    policy.Port,
-									Allow:           true,
-								})
-							}
-						}
-					}
-				}
-
+			if ok, allowedPolicies := IsNodeAllowedToCommunicateV1(nodeI.StaticNode.ConvertToStaticNode(), peer, true); ok {
+				rules = append(rules, getFwRulesForNodeAndPeerOnGw(nodeI.StaticNode.ConvertToStaticNode(), peer, allowedPolicies)...)
+			}
+			if ok, allowedPolicies := IsNodeAllowedToCommunicateV1(peer, nodeI.StaticNode.ConvertToStaticNode(), true); ok {
+				rules = append(rules, getFwRulesForNodeAndPeerOnGw(peer, nodeI.StaticNode.ConvertToStaticNode(), allowedPolicies)...)
 			}
 		}
 	}
@@ -729,11 +709,13 @@ func GetExtPeers(node, peer *models.Node) ([]wgtypes.PeerConfig, []models.IDandA
 	}
 	for _, extPeer := range extPeers {
 		extPeer := extPeer
+		fmt.Println("=====> checking EXT peer: ", extPeer.ClientID)
 		if !IsClientNodeAllowed(&extPeer, peer.ID.String()) {
 			continue
 		}
 		if extPeer.RemoteAccessClientID == "" {
 			if ok := IsPeerAllowed(extPeer.ConvertToStaticNode(), *peer, true); !ok {
+				fmt.Println("=====>1 checking EXT peer: ", extPeer.ClientID)
 				continue
 			}
 		} else {
@@ -822,7 +804,7 @@ func getExtpeerEgressRanges(node models.Node) (ranges, ranges6 []net.IPNet) {
 		if len(extPeer.ExtraAllowedIPs) == 0 {
 			continue
 		}
-		if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node, true); !ok {
+		if ok, _ := IsNodeAllowedToCommunicateV1(extPeer.ConvertToStaticNode(), node, true); !ok {
 			continue
 		}
 		for _, allowedRange := range extPeer.ExtraAllowedIPs {
@@ -849,7 +831,7 @@ func getExtpeersExtraRoutes(node models.Node) (egressRoutes []models.EgressNetwo
 		if len(extPeer.ExtraAllowedIPs) == 0 {
 			continue
 		}
-		if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node, true); !ok {
+		if ok, _ := IsNodeAllowedToCommunicateV1(extPeer.ConvertToStaticNode(), node, true); !ok {
 			continue
 		}
 		egressRoutes = append(egressRoutes, getExtPeerEgressRoute(node, extPeer)...)

+ 3 - 1
models/extclient.go

@@ -50,7 +50,9 @@ func (ext *ExtClient) ConvertToStaticNode() Node {
 
 	return Node{
 		CommonNode: CommonNode{
-			Network: ext.Network,
+			Network:  ext.Network,
+			Address:  ext.AddressIPNet4(),
+			Address6: ext.AddressIPNet6(),
 		},
 		Tags:       ext.Tags,
 		IsStatic:   true,

+ 3 - 3
pro/logic/status.go

@@ -41,7 +41,7 @@ func GetNodeStatus(node *models.Node, defaultEnabledPolicy bool) {
 			return
 		}
 		if !defaultEnabledPolicy {
-			allowed, _ := logic.IsNodeAllowedToCommunicate(*node, ingNode, false)
+			allowed, _ := logic.IsNodeAllowedToCommunicateV1(*node, ingNode, false)
 			if !allowed {
 				node.Status = models.OnlineSt
 				return
@@ -161,7 +161,7 @@ func checkPeerStatus(node *models.Node, defaultAclPolicy bool) {
 		}
 
 		if !defaultAclPolicy {
-			allowed, _ := logic.IsNodeAllowedToCommunicate(*node, peer, false)
+			allowed, _ := logic.IsNodeAllowedToCommunicateV1(*node, peer, false)
 			if !allowed {
 				continue
 			}
@@ -199,7 +199,7 @@ func checkPeerConnectivity(node *models.Node, metrics *models.Metrics, defaultAc
 		}
 
 		if !defaultAclPolicy {
-			allowed, _ := logic.IsNodeAllowedToCommunicate(*node, peer, false)
+			allowed, _ := logic.IsNodeAllowedToCommunicateV1(*node, peer, false)
 			if !allowed {
 				continue
 			}