Переглянути джерело

Merge pull request #3238 from gravitl/NET-1784-userRules

Net 1784 user rules
Abhishek K 9 місяців тому
батько
коміт
e47baab30e
6 змінених файлів з 210 додано та 128 видалено
  1. 1 1
      controllers/acls.go
  2. 93 53
      logic/acls.go
  3. 98 60
      logic/extpeers.go
  4. 6 4
      logic/peers.go
  5. 7 7
      models/acl.go
  6. 5 3
      models/mqtt.go

+ 1 - 1
controllers/acls.go

@@ -136,7 +136,7 @@ func aclDebug(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	allowed := logic.IsNodeAllowedToCommunicate(node, peer)
+	allowed, _ := logic.IsNodeAllowedToCommunicate(node, peer)
 	logic.ReturnSuccessResponseWithJson(w, r, allowed, "fetched all acls in the network ")
 }
 

+ 93 - 53
logic/acls.go

@@ -34,21 +34,21 @@ func MigrateDefaulAclPolicies(netID models.NetworkID) {
 	}
 	acl, err = GetAcl(fmt.Sprintf("%s.%s", netID, "all-users"))
 	if err == nil {
-		if acl.Proto.String() == "" {
-			acl.Proto = models.ALL
-			acl.ServiceType = models.Custom
-			acl.Port = []string{}
-			UpsertAcl(acl)
-		}
+		//if acl.Proto.String() == "" {
+		acl.Proto = models.ALL
+		acl.ServiceType = models.Custom
+		acl.Port = []string{}
+		UpsertAcl(acl)
+		//}
 	}
 	acl, err = GetAcl(fmt.Sprintf("%s.%s", netID, "all-remote-access-gws"))
 	if err == nil {
-		if acl.Proto.String() == "" {
-			acl.Proto = models.ALL
-			acl.ServiceType = models.Custom
-			acl.Port = []string{}
-			UpsertAcl(acl)
-		}
+		//if acl.Proto.String() == "" {
+		acl.Proto = models.ALL
+		acl.ServiceType = models.Custom
+		acl.Port = []string{}
+		UpsertAcl(acl)
+		//}
 	}
 }
 
@@ -526,19 +526,19 @@ func convAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} {
 }
 
 // IsUserAllowedToCommunicate - check if user is allowed to communicate with peer
-func IsUserAllowedToCommunicate(userName string, peer models.Node) bool {
+func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []models.Acl) {
 	if peer.IsStatic {
 		peer = peer.StaticNode.ConvertToStaticNode()
 	}
 	acl, _ := GetDefaultPolicy(models.NetworkID(peer.Network), models.UserPolicy)
 	if acl.Enabled {
-		return true
+		return true, []models.Acl{acl}
 	}
 	user, err := GetUser(userName)
 	if err != nil {
-		return false
+		return false, []models.Acl{}
 	}
-
+	allowedPolicies := []models.Acl{}
 	policies := listPoliciesOfUser(*user, models.NetworkID(peer.Network))
 	for _, policy := range policies {
 		if !policy.Enabled {
@@ -546,20 +546,25 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) bool {
 		}
 		dstMap := convAclTagToValueMap(policy.Dst)
 		if _, ok := dstMap["*"]; ok {
-			return true
+			allowedPolicies = append(allowedPolicies, policy)
+			continue
 		}
 		for tagID := range peer.Tags {
 			if _, ok := dstMap[tagID.String()]; ok {
-				return true
+				allowedPolicies = append(allowedPolicies, policy)
+				break
 			}
 		}
 
 	}
-	return false
+	if len(allowedPolicies) > 0 {
+		return true, allowedPolicies
+	}
+	return false, []models.Acl{}
 }
 
 // IsNodeAllowedToCommunicate - check node is allowed to communicate with the peer
-func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
+func IsNodeAllowedToCommunicate(node, peer models.Node) (bool, []models.Acl) {
 	if node.IsStatic {
 		node = node.StaticNode.ConvertToStaticNode()
 	}
@@ -570,10 +575,10 @@ func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
 	defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 	if err == nil {
 		if defaultPolicy.Enabled {
-			return true
+			return true, []models.Acl{defaultPolicy}
 		}
 	}
-
+	allowedPolicies := []models.Acl{}
 	// list device policies
 	policies := listDevicePolicies(models.NetworkID(peer.Network))
 	for _, policy := range policies {
@@ -587,52 +592,86 @@ func IsNodeAllowedToCommunicate(node, peer models.Node) bool {
 		// fmt.Printf("\n======> node Tags: %+v\n", node.Tags)
 		// fmt.Printf("\n======> peer Tags: %+v\n", peer.Tags)
 		for tagID := range node.Tags {
+			allowed := false
 			if _, ok := dstMap[tagID.String()]; ok {
 				if _, ok := srcMap["*"]; ok {
-					return true
+					allowed = true
+					allowedPolicies = append(allowedPolicies, policy)
+					break
 				}
 				for tagID := range peer.Tags {
 					if _, ok := srcMap[tagID.String()]; ok {
-						return true
+						allowed = true
+						break
 					}
 				}
 			}
+			if allowed {
+				allowedPolicies = append(allowedPolicies, policy)
+				break
+			}
 			if _, ok := srcMap[tagID.String()]; ok {
 				if _, ok := dstMap["*"]; ok {
-					return true
+					allowed = true
+					allowedPolicies = append(allowedPolicies, policy)
+					break
 				}
 				for tagID := range peer.Tags {
 					if _, ok := dstMap[tagID.String()]; ok {
-						return true
+						allowed = true
+						break
 					}
 				}
 			}
+			if allowed {
+				allowedPolicies = append(allowedPolicies, policy)
+				break
+			}
 		}
 		for tagID := range peer.Tags {
+			allowed := false
 			if _, ok := dstMap[tagID.String()]; ok {
 				if _, ok := srcMap["*"]; ok {
-					return true
+					allowed = true
+					allowedPolicies = append(allowedPolicies, policy)
+					break
 				}
 				for tagID := range node.Tags {
 
 					if _, ok := srcMap[tagID.String()]; ok {
-						return true
+						allowed = true
+						break
 					}
 				}
 			}
+			if allowed {
+				allowedPolicies = append(allowedPolicies, policy)
+				break
+			}
+
 			if _, ok := srcMap[tagID.String()]; ok {
 				if _, ok := dstMap["*"]; ok {
-					return true
+					allowed = true
+					allowedPolicies = append(allowedPolicies, policy)
+					break
 				}
 				for tagID := range node.Tags {
 					if _, ok := dstMap[tagID.String()]; ok {
-						return true
+						allowed = true
+						break
 					}
 				}
 			}
+			if allowed {
+				allowedPolicies = append(allowedPolicies, policy)
+				break
+			}
 		}
 	}
-	return false
+	if len(allowedPolicies) > 0 {
+		return true, allowedPolicies
+	}
+	return false, allowedPolicies
 }
 
 // SortTagEntrys - Sorts slice of Tag entries by their id
@@ -720,13 +759,14 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
 	defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 	rules = make(map[string]models.AclRule)
 	if err == nil && defaultPolicy.Enabled {
+		return
 		return map[string]models.AclRule{
 			defaultPolicy.ID: {
-				IPList:           []net.IPNet{node.NetworkRange},
-				IP6List:          []net.IPNet{node.NetworkRange6},
-				AllowedProtocols: models.ALL,
-				Direction:        models.TrafficDirectionBi,
-				Allowed:          true,
+				IPList:          []net.IPNet{node.NetworkRange},
+				IP6List:         []net.IPNet{node.NetworkRange6},
+				AllowedProtocol: models.ALL,
+				Direction:       models.TrafficDirectionBi,
+				Allowed:         true,
 			},
 		}
 	}
@@ -742,11 +782,11 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
 			srcTags := convAclTagToValueMap(acl.Src)
 			dstTags := convAclTagToValueMap(acl.Dst)
 			aclRule := models.AclRule{
-				ID:               acl.ID,
-				AllowedProtocols: acl.Proto,
-				AllowedPorts:     acl.Port,
-				Direction:        acl.AllowedDirection,
-				Allowed:          true,
+				ID:              acl.ID,
+				AllowedProtocol: acl.Proto,
+				AllowedPorts:    acl.Port,
+				Direction:       acl.AllowedDirection,
+				Allowed:         true,
 			}
 			if acl.AllowedDirection == models.TrafficDirectionBi {
 				var existsInSrcTag bool
@@ -755,24 +795,24 @@ func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
 				if _, ok := srcTags["*"]; ok {
 					return map[string]models.AclRule{
 						acl.ID: {
-							IPList:           []net.IPNet{node.NetworkRange},
-							IP6List:          []net.IPNet{node.NetworkRange6},
-							AllowedProtocols: models.ALL,
-							AllowedPorts:     acl.Port,
-							Direction:        acl.AllowedDirection,
-							Allowed:          true,
+							IPList:          []net.IPNet{node.NetworkRange},
+							IP6List:         []net.IPNet{node.NetworkRange6},
+							AllowedProtocol: models.ALL,
+							AllowedPorts:    acl.Port,
+							Direction:       acl.AllowedDirection,
+							Allowed:         true,
 						},
 					}
 				}
 				if _, ok := dstTags["*"]; ok {
 					return map[string]models.AclRule{
 						acl.ID: {
-							IPList:           []net.IPNet{node.NetworkRange},
-							IP6List:          []net.IPNet{node.NetworkRange6},
-							AllowedProtocols: models.ALL,
-							AllowedPorts:     acl.Port,
-							Direction:        acl.AllowedDirection,
-							Allowed:          true,
+							IPList:          []net.IPNet{node.NetworkRange},
+							IP6List:         []net.IPNet{node.NetworkRange6},
+							AllowedProtocol: models.ALL,
+							AllowedPorts:    acl.Port,
+							Direction:       acl.AllowedDirection,
+							Allowed:         true,
 						},
 					}
 				}

+ 98 - 60
logic/extpeers.go

@@ -457,7 +457,7 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
 func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 	// fetch user access to static clients via policies
 	defer func() {
-		logger.Log(0, fmt.Sprintf("%+v\n", rules))
+		logger.Log(0, fmt.Sprintf("node.ID: %s, Rules: %+v\n", node.ID, rules))
 	}()
 	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@@ -471,15 +471,21 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 			if peer.IsUserNode {
 				continue
 			}
-			if IsUserAllowedToCommunicate(userNodeI.StaticNode.OwnerID, peer) {
+			if ok, allowedPolicies := IsUserAllowedToCommunicate(userNodeI.StaticNode.OwnerID, peer); ok {
 				if peer.IsStatic {
 					if userNodeI.StaticNode.Address != "" {
 						if !defaultUserPolicy.Enabled {
-							rules = append(rules, models.FwRule{
-								SrcIP: userNodeI.StaticNode.AddressIPNet4(),
-								DstIP: peer.StaticNode.AddressIPNet4(),
-								Allow: true,
-							})
+							for _, policy := range allowedPolicies {
+								rules = append(rules, models.FwRule{
+									SrcIP:           userNodeI.StaticNode.AddressIPNet4(),
+									DstIP:           peer.StaticNode.AddressIPNet4(),
+									AllowedProtocol: policy.Proto,
+									AllowedPorts:    policy.Port,
+									Allow:           true,
+								})
+
+							}
+
 						}
 						rules = append(rules, models.FwRule{
 							SrcIP: peer.StaticNode.AddressIPNet4(),
@@ -489,11 +495,16 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 					}
 					if userNodeI.StaticNode.Address6 != "" {
 						if !defaultUserPolicy.Enabled {
-							rules = append(rules, models.FwRule{
-								SrcIP: userNodeI.StaticNode.AddressIPNet6(),
-								DstIP: peer.StaticNode.AddressIPNet6(),
-								Allow: true,
-							})
+							for _, policy := range allowedPolicies {
+								rules = append(rules, models.FwRule{
+									SrcIP:           userNodeI.StaticNode.AddressIPNet6(),
+									DstIP:           peer.StaticNode.AddressIPNet6(),
+									Allow:           true,
+									AllowedProtocol: policy.Proto,
+									AllowedPorts:    policy.Port,
+								})
+
+							}
 						}
 
 						rules = append(rules, models.FwRule{
@@ -529,29 +540,39 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 
 					if userNodeI.StaticNode.Address != "" {
 						if !defaultUserPolicy.Enabled {
-							rules = append(rules, models.FwRule{
-								SrcIP: userNodeI.StaticNode.AddressIPNet4(),
-								DstIP: net.IPNet{
-									IP:   peer.Address.IP,
-									Mask: net.CIDRMask(32, 32),
-								},
-								Allow: true,
-							})
+							for _, policy := range allowedPolicies {
+								rules = append(rules, models.FwRule{
+									SrcIP: userNodeI.StaticNode.AddressIPNet4(),
+									DstIP: net.IPNet{
+										IP:   peer.Address.IP,
+										Mask: net.CIDRMask(32, 32),
+									},
+									AllowedProtocol: policy.Proto,
+									AllowedPorts:    policy.Port,
+									Allow:           true,
+								})
+							}
+
 						}
 					}
 
 					if userNodeI.StaticNode.Address6 != "" {
-						rules = append(rules, models.FwRule{
-							SrcIP: userNodeI.StaticNode.AddressIPNet6(),
-							DstIP: net.IPNet{
-								IP:   peer.Address6.IP,
-								Mask: net.CIDRMask(128, 128),
-							},
-							Allow: true,
-						})
+						if !defaultUserPolicy.Enabled {
+							for _, policy := range allowedPolicies {
+								rules = append(rules, models.FwRule{
+									SrcIP: userNodeI.StaticNode.AddressIPNet6(),
+									DstIP: net.IPNet{
+										IP:   peer.Address6.IP,
+										Mask: net.CIDRMask(128, 128),
+									},
+									AllowedProtocol: policy.Proto,
+									AllowedPorts:    policy.Port,
+									Allow:           true,
+								})
+							}
+						}
 					}
 				}
-
 			}
 		}
 	}
@@ -567,21 +588,30 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 			if peer.StaticNode.ClientID == nodeI.StaticNode.ClientID || peer.IsUserNode {
 				continue
 			}
-			if IsNodeAllowedToCommunicate(nodeI, peer) {
+			if ok, allowedPolicies := IsNodeAllowedToCommunicate(nodeI, peer); ok {
 				if peer.IsStatic {
 					if nodeI.StaticNode.Address != "" {
-						rules = append(rules, models.FwRule{
-							SrcIP: nodeI.StaticNode.AddressIPNet4(),
-							DstIP: peer.StaticNode.AddressIPNet4(),
-							Allow: true,
-						})
+						for _, policy := range allowedPolicies {
+							rules = append(rules, models.FwRule{
+								SrcIP:           nodeI.StaticNode.AddressIPNet4(),
+								DstIP:           peer.StaticNode.AddressIPNet4(),
+								AllowedProtocol: policy.Proto,
+								AllowedPorts:    policy.Port,
+								Allow:           true,
+							})
+						}
+
 					}
 					if nodeI.StaticNode.Address6 != "" {
-						rules = append(rules, models.FwRule{
-							SrcIP: nodeI.StaticNode.AddressIPNet6(),
-							DstIP: peer.StaticNode.AddressIPNet6(),
-							Allow: true,
-						})
+						for _, policy := range allowedPolicies {
+							rules = append(rules, models.FwRule{
+								SrcIP:           nodeI.StaticNode.AddressIPNet6(),
+								DstIP:           peer.StaticNode.AddressIPNet6(),
+								AllowedProtocol: policy.Proto,
+								AllowedPorts:    policy.Port,
+								Allow:           true,
+							})
+						}
 					}
 					if len(peer.StaticNode.ExtraAllowedIPs) > 0 {
 						for _, additionalAllowedIPNet := range peer.StaticNode.ExtraAllowedIPs {
@@ -608,24 +638,32 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 					}
 				} else {
 					if nodeI.StaticNode.Address != "" {
-						rules = append(rules, models.FwRule{
-							SrcIP: nodeI.StaticNode.AddressIPNet4(),
-							DstIP: net.IPNet{
-								IP:   peer.Address.IP,
-								Mask: net.CIDRMask(32, 32),
-							},
-							Allow: true,
-						})
+						for _, policy := range allowedPolicies {
+							rules = append(rules, models.FwRule{
+								SrcIP: nodeI.StaticNode.AddressIPNet4(),
+								DstIP: net.IPNet{
+									IP:   peer.Address.IP,
+									Mask: net.CIDRMask(32, 32),
+								},
+								AllowedProtocol: policy.Proto,
+								AllowedPorts:    policy.Port,
+								Allow:           true,
+							})
+						}
 					}
 					if nodeI.StaticNode.Address6 != "" {
-						rules = append(rules, models.FwRule{
-							SrcIP: nodeI.StaticNode.AddressIPNet6(),
-							DstIP: net.IPNet{
-								IP:   peer.Address6.IP,
-								Mask: net.CIDRMask(128, 128),
-							},
-							Allow: true,
-						})
+						for _, policy := range allowedPolicies {
+							rules = append(rules, models.FwRule{
+								SrcIP: nodeI.StaticNode.AddressIPNet6(),
+								DstIP: net.IPNet{
+									IP:   peer.Address6.IP,
+									Mask: net.CIDRMask(128, 128),
+								},
+								AllowedProtocol: policy.Proto,
+								AllowedPorts:    policy.Port,
+								Allow:           true,
+							})
+						}
 					}
 				}
 
@@ -653,11 +691,11 @@ func GetExtPeers(node, peer *models.Node) ([]wgtypes.PeerConfig, []models.IDandA
 			continue
 		}
 		if extPeer.RemoteAccessClientID == "" {
-			if !IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), *peer) {
+			if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), *peer); !ok {
 				continue
 			}
 		} else {
-			if !IsUserAllowedToCommunicate(extPeer.OwnerID, *peer) {
+			if ok, _ := IsUserAllowedToCommunicate(extPeer.OwnerID, *peer); !ok {
 				continue
 			}
 		}
@@ -742,7 +780,7 @@ func getExtpeerEgressRanges(node models.Node) (ranges, ranges6 []net.IPNet) {
 		if len(extPeer.ExtraAllowedIPs) == 0 {
 			continue
 		}
-		if !IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node) {
+		if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node); !ok {
 			continue
 		}
 		for _, allowedRange := range extPeer.ExtraAllowedIPs {
@@ -769,7 +807,7 @@ func getExtpeersExtraRoutes(node models.Node) (egressRoutes []models.EgressNetwo
 		if len(extPeer.ExtraAllowedIPs) == 0 {
 			continue
 		}
-		if !IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node) {
+		if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node); !ok {
 			continue
 		}
 		egressRoutes = append(egressRoutes, getExtPeerEgressRoute(node, extPeer)...)

+ 6 - 4
logic/peers.go

@@ -164,14 +164,15 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		if node.NetworkRange6.IP != nil {
 			hostPeerUpdate.FwUpdate.Networks = append(hostPeerUpdate.FwUpdate.Networks, node.NetworkRange6)
 		}
-		if host.Name == "Test-Server" {
-			fmt.Println("##### DEF POL ", defaultDevicePolicy.Enabled, defaultUserPolicy.Enabled)
-		}
 
 		if !defaultDevicePolicy.Enabled || !defaultUserPolicy.Enabled {
 			hostPeerUpdate.FwUpdate.AllowAll = false
 		}
 		hostPeerUpdate.FwUpdate.AclRules = GetAclRulesForNode(&node)
+		if host.Name == "Test-Server" {
+			fmt.Println("##### DEF POL ", defaultDevicePolicy.Enabled, defaultUserPolicy.Enabled)
+			fmt.Printf("ACL Rules: %+v\n", hostPeerUpdate.FwUpdate.AclRules)
+		}
 		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
 		for _, peer := range currentPeers {
 			peer := peer
@@ -273,11 +274,12 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				peerConfig.Endpoint.Port = peerHost.ListenPort
 			}
 			allowedips := GetAllowedIPs(&node, &peer, nil)
+			allowedToComm, _ := IsNodeAllowedToCommunicate(node, peer)
 			if peer.Action != models.NODE_DELETE &&
 				!peer.PendingDelete &&
 				peer.Connected &&
 				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
-				IsNodeAllowedToCommunicate(node, peer) &&
+				allowedToComm &&
 				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
 				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
 			}

+ 7 - 7
models/acl.go

@@ -109,11 +109,11 @@ type ProtocolType struct {
 }
 
 type AclRule struct {
-	ID               string                  `json:"id"`
-	IPList           []net.IPNet             `json:"ip_list"`
-	IP6List          []net.IPNet             `json:"ip6_list"`
-	AllowedProtocols Protocol                `json:"allowed_protocols"` // tcp, udp, etc.
-	AllowedPorts     []string                `json:"allowed_ports"`
-	Direction        AllowedTrafficDirection `json:"direction"` // single or two-way
-	Allowed          bool
+	ID              string                  `json:"id"`
+	IPList          []net.IPNet             `json:"ip_list"`
+	IP6List         []net.IPNet             `json:"ip6_list"`
+	AllowedProtocol Protocol                `json:"allowed_protocols"` // tcp, udp, etc.
+	AllowedPorts    []string                `json:"allowed_ports"`
+	Direction       AllowedTrafficDirection `json:"direction"` // single or two-way
+	Allowed         bool
 }

+ 5 - 3
models/mqtt.go

@@ -28,9 +28,11 @@ type HostPeerUpdate struct {
 }
 
 type FwRule struct {
-	SrcIP net.IPNet `json:"src_ip"`
-	DstIP net.IPNet `json:"dst_ip"`
-	Allow bool      `json:"allow"`
+	SrcIP           net.IPNet `json:"src_ip"`
+	DstIP           net.IPNet `json:"dst_ip"`
+	AllowedProtocol Protocol  `json:"allowed_protocols"` // tcp, udp, etc.
+	AllowedPorts    []string  `json:"allowed_ports"`
+	Allow           bool      `json:"allow"`
 }
 
 // IngressInfo - struct for ingress info