فهرست منبع

fix egress updates for relayed nodes

Abhishek Kondur 2 سال پیش
والد
کامیت
e152905cb0
1فایلهای تغییر یافته به همراه46 افزوده شده و 51 حذف شده
  1. 46 51
      logic/peers.go

+ 46 - 51
logic/peers.go

@@ -162,6 +162,38 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					PersistentKeepaliveInterval: &peer.PersistentKeepalive,
 					ReplaceAllowedIPs:           true,
 				}
+				if node.IsIngressGateway || node.IsEgressGateway {
+					if peer.IsIngressGateway {
+						_, extPeerIDAndAddrs, err := getExtPeers(&peer)
+						if err == nil {
+							for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
+								extPeerIdAndAddr := extPeerIdAndAddr
+								nodePeerMap[extPeerIdAndAddr.ID] = models.PeerRouteInfo{
+									PeerAddr: net.IPNet{
+										IP:   net.ParseIP(extPeerIdAndAddr.Address),
+										Mask: getCIDRMaskFromAddr(extPeerIdAndAddr.Address),
+									},
+									PeerKey: extPeerIdAndAddr.ID,
+									Allow:   true,
+									ID:      extPeerIdAndAddr.ID,
+								}
+							}
+						}
+					}
+					if node.IsIngressGateway && peer.IsEgressGateway {
+						hostPeerUpdate.IngressInfo.EgressRanges = append(hostPeerUpdate.IngressInfo.EgressRanges,
+							peer.EgressGatewayRanges...)
+					}
+					nodePeerMap[peerHost.PublicKey.String()] = models.PeerRouteInfo{
+						PeerAddr: net.IPNet{
+							IP:   net.ParseIP(peer.PrimaryAddress()),
+							Mask: getCIDRMaskFromAddr(peer.PrimaryAddress()),
+						},
+						PeerKey: peerHost.PublicKey.String(),
+						Allow:   true,
+						ID:      peer.ID.String(),
+					}
+				}
 				if (node.IsRelayed && node.RelayedBy != peer.ID.String()) || (peer.IsRelayed && peer.RelayedBy != node.ID.String()) {
 					// if node is relayed and peer is not the relay, set remove to true
 					if _, ok := hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()]; ok {
@@ -204,39 +236,6 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
 				}
 
-				if node.IsIngressGateway || node.IsEgressGateway {
-					if peer.IsIngressGateway {
-						_, extPeerIDAndAddrs, err := getExtPeers(&peer)
-						if err == nil {
-							for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
-								extPeerIdAndAddr := extPeerIdAndAddr
-								nodePeerMap[extPeerIdAndAddr.ID] = models.PeerRouteInfo{
-									PeerAddr: net.IPNet{
-										IP:   net.ParseIP(extPeerIdAndAddr.Address),
-										Mask: getCIDRMaskFromAddr(extPeerIdAndAddr.Address),
-									},
-									PeerKey: extPeerIdAndAddr.ID,
-									Allow:   true,
-									ID:      extPeerIdAndAddr.ID,
-								}
-							}
-						}
-					}
-					if node.IsIngressGateway && peer.IsEgressGateway {
-						hostPeerUpdate.IngressInfo.EgressRanges = append(hostPeerUpdate.IngressInfo.EgressRanges,
-							peer.EgressGatewayRanges...)
-					}
-					nodePeerMap[peerHost.PublicKey.String()] = models.PeerRouteInfo{
-						PeerAddr: net.IPNet{
-							IP:   net.ParseIP(peer.PrimaryAddress()),
-							Mask: getCIDRMaskFromAddr(peer.PrimaryAddress()),
-						},
-						PeerKey: peerHost.PublicKey.String(),
-						Allow:   true,
-						ID:      peer.ID.String(),
-					}
-				}
-
 				peerProxyPort := GetProxyListenPort(peerHost)
 				var nodePeer wgtypes.PeerConfig
 				if _, ok := hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()]; !ok {
@@ -584,6 +583,9 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
 	}
 	if node.IsRelayed && node.RelayedBy == peer.ID.String() {
 		allowedips = append(allowedips, getAllowedIpsForRelayed(node, peer)...)
+		if node.IsEgressGateway {
+			allowedips = append(allowedips, getEgressIPs(node)...)
+		}
 	}
 	return allowedips
 }
@@ -649,11 +651,18 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
 		allowedips = append(allowedips, egressIPs...)
 	}
 	if peer.IsRelay {
-		for _, relayed := range peer.RelayedNodes {
-			if node.ID.String() == relayed {
+		for _, relayedNodeID := range peer.RelayedNodes {
+			if node.ID.String() == relayedNodeID {
 				continue
 			}
-			allowed := getRelayedAddresses(relayed)
+			relayedNode, err := GetNodeByID(relayedNodeID)
+			if err != nil {
+				continue
+			}
+			allowed := getRelayedAddresses(relayedNodeID)
+			if relayedNode.IsEgressGateway {
+				allowed = append(allowed, getEgressIPs(&relayedNode)...)
+			}
 			allowedips = append(allowedips, allowed...)
 		}
 	}
@@ -663,23 +672,9 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
 // getAllowedIpsForRelayed - returns the peerConfig for a node relayed by relay
 func getAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNet) {
 	if relayed.RelayedBy != relay.ID.String() {
-		logger.Log(0, "peerUpdateForRelayedByRelay called with invalid parameters")
+		logger.Log(0, "RelayedByRelay called with invalid parameters")
 		return
 	}
-	if relay.Address.IP != nil {
-		relay.Address.Mask = net.CIDRMask(32, 32)
-		allowedIPs = append(allowedIPs, relay.Address)
-	}
-	if relay.Address6.IP != nil {
-		relay.Address6.Mask = net.CIDRMask(128, 128)
-		allowedIPs = append(allowedIPs, relay.Address6)
-	}
-	if relay.IsEgressGateway {
-		allowedIPs = append(allowedIPs, getEgressIPs(relay)...)
-	}
-	if relay.IsIngressGateway {
-		allowedIPs = append(allowedIPs, getIngressIPs(relay)...)
-	}
 	peers, err := GetNetworkNodes(relay.Network)
 	if err != nil {
 		logger.Log(0, "error getting network clients", err.Error())