Browse Source

get relayed allowed ips

Abhishek Kondur 2 năm trước cách đây
mục cha
commit
5be2d420fc
1 tập tin đã thay đổi với 87 bổ sung25 xóa
  1. 87 25
      logic/peers.go

+ 87 - 25
logic/peers.go

@@ -151,10 +151,7 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					//skip yourself
 					continue
 				}
-				if peer.IsRelayed {
-					// skip relayed peers; will be included in relay peer
-					continue
-				}
+
 				peerHost, err := GetHost(peer.HostID.String())
 				if err != nil {
 					logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
@@ -165,6 +162,13 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					PersistentKeepaliveInterval: &peer.PersistentKeepalive,
 					ReplaceAllowedIPs:           true,
 				}
+				if peer.IsRelayed && peer.RelayedBy != node.ID.String() {
+					// skip relayed peers; will be included in relay peer
+					peerConfig.Remove = true
+					hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
+					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
+					continue
+				}
 				if node.IsRelayed && node.RelayedBy != peer.ID.String() {
 					// if node is relayed and peer is not the relay, set remove to true
 					peerConfig.Remove = true
@@ -172,6 +176,7 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
 					continue
 				}
+
 				uselocal := false
 				if host.EndpointIP.String() == peerHost.EndpointIP.String() {
 					// peer is on same network
@@ -195,17 +200,6 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					peerConfig.Endpoint.Port = peerHost.ListenPort
 				}
 				allowedips := GetAllowedIPs(&node, &peer, nil)
-				if peer.IsIngressGateway {
-					for _, entry := range peer.IngressGatewayRange {
-						_, cidr, err := net.ParseCIDR(string(entry))
-						if err == nil {
-							allowedips = append(allowedips, *cidr)
-						}
-					}
-				}
-				if peer.IsEgressGateway {
-					allowedips = append(allowedips, getEgressIPs(&node, &peer)...)
-				}
 				if peer.Action != models.NODE_DELETE &&
 					!peer.PendingDelete &&
 					peer.Connected &&
@@ -267,7 +261,7 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					nodePeer = peerConfig
 				} else {
 					peerAllowedIPs := hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].AllowedIPs
-					peerAllowedIPs = append(peerAllowedIPs, allowedips...)
+					peerAllowedIPs = append(peerAllowedIPs, peerConfig.AllowedIPs...)
 					hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].AllowedIPs = peerAllowedIPs
 					hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].Remove = false
 					hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()][peer.ID.String()] = models.IDandAddr{
@@ -592,14 +586,14 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
 			}
 		}
 	}
+	if node.IsRelayed && node.RelayedBy == peer.ID.String() {
+		allowedips = append(allowedips, getAllowedIpsForRelayed(node, peer)...)
+	}
 	return allowedips
 }
 
-func getEgressIPs(node, peer *models.Node) []net.IPNet {
-	host, err := GetHost(node.HostID.String())
-	if err != nil {
-		logger.Log(0, "error retrieving host for node", node.ID.String(), err.Error())
-	}
+func getEgressIPs(peer *models.Node) []net.IPNet {
+
 	peerHost, err := GetHost(peer.HostID.String())
 	if err != nil {
 		logger.Log(0, "error retrieving host for peer", peer.ID.String(), err.Error())
@@ -619,12 +613,12 @@ func getEgressIPs(node, peer *models.Node) []net.IPNet {
 		}
 		// getting the public ip of node
 		if ipnet.Contains(peerHost.EndpointIP) && !internetGateway { // ensuring egress gateway range does not contain endpoint of node
-			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", host.EndpointIP.String(), ", omitting")
+			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", peerHost.EndpointIP.String(), ", omitting")
 			continue // skip adding egress range if overlaps with node's ip
 		}
 		// TODO: Could put in a lot of great logic to avoid conflicts / bad routes
-		if ipnet.Contains(node.LocalAddress.IP) && !internetGateway { // ensuring egress gateway range does not contain public ip of node
-			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", node.LocalAddress.String(), ", omitting")
+		if ipnet.Contains(peer.LocalAddress.IP) && !internetGateway { // ensuring egress gateway range does not contain public ip of node
+			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", peer.LocalAddress.String(), ", omitting")
 			continue // skip adding egress range if overlaps with node's local ip
 		}
 		if err != nil {
@@ -655,11 +649,14 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
 	// handle egress gateway peers
 	if peer.IsEgressGateway {
 		//hasGateway = true
-		egressIPs := getEgressIPs(node, peer)
+		egressIPs := getEgressIPs(peer)
 		allowedips = append(allowedips, egressIPs...)
 	}
 	if peer.IsRelay {
 		for _, relayed := range peer.RelayedNodes {
+			if node.ID.String() == relayed {
+				continue
+			}
 			allowed := getRelayedAddresses(relayed)
 			allowedips = append(allowedips, allowed...)
 		}
@@ -667,6 +664,71 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
 	return allowedips
 }
 
+// getAllowedIpsForRelayed - returns the peerConfig for a node relayed by relay
+func getAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNet) {
+	if relayed.RelayedBy != relay.ID.String() {
+		logger.Log(0, "peerUpdateForRelayedByRelay called with invalid parameters")
+		return
+	}
+	if relay.Address.IP != nil {
+		relay.Address.Mask = net.CIDRMask(32, 32)
+		allowedIPs = append(allowedIPs, relay.Address)
+	}
+	if relay.Address6.IP != nil {
+		relay.Address6.Mask = net.CIDRMask(128, 128)
+		allowedIPs = append(allowedIPs, relay.Address6)
+	}
+	if relay.IsEgressGateway {
+		allowedIPs = append(allowedIPs, getEgressIPs(relay)...)
+	}
+	if relay.IsIngressGateway {
+		allowedIPs = append(allowedIPs, getIngressIPs(relay)...)
+	}
+	peers, err := GetNetworkNodes(relay.Network)
+	if err != nil {
+		logger.Log(0, "error getting network clients", err.Error())
+		return
+	}
+	for _, peer := range peers {
+		if peer.ID == relayed.ID || peer.ID == relay.ID {
+			continue
+		}
+		if nodeacls.AreNodesAllowed(nodeacls.NetworkID(relayed.Network), nodeacls.NodeID(relayed.ID.String()), nodeacls.NodeID(peer.ID.String())) {
+			allowedIPs = append(allowedIPs, GetAllowedIPs(relayed, &peer, nil)...)
+		}
+	}
+	return
+}
+
+func getIngressIPs(peer *models.Node) []net.IPNet {
+	var ingressIPs []net.IPNet
+	extclients, err := GetNetworkExtClients(peer.Network)
+	if err != nil {
+		return ingressIPs
+	}
+	for _, ec := range extclients {
+		if ec.IngressGatewayID == peer.ID.String() {
+			if ec.Address != "" {
+				ip, cidr, err := net.ParseCIDR(ec.Address)
+				if err != nil {
+					continue
+				}
+				cidr.IP = ip
+				ingressIPs = append(ingressIPs, *cidr)
+			}
+			if ec.Address6 != "" {
+				ip, cidr, err := net.ParseCIDR(ec.Address6)
+				if err != nil {
+					continue
+				}
+				cidr.IP = ip
+				ingressIPs = append(ingressIPs, *cidr)
+			}
+		}
+	}
+	return ingressIPs
+}
+
 func getCIDRMaskFromAddr(addr string) net.IPMask {
 	cidr := net.CIDRMask(32, 32)
 	ipAddr, err := netip.ParseAddr(addr)