Browse Source

add logic for calculating peers when relayed node is egress gateway

Matthew R Kasun 3 years ago
parent
commit
6ab994bd9e
1 changed files with 39 additions and 24 deletions
  1. 39 24
      logic/peers.go

+ 39 - 24
logic/peers.go

@@ -267,30 +267,10 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 	// handle egress gateway peers
 	if peer.IsEgressGateway == "yes" {
 		//hasGateway = true
-		ranges := peer.EgressGatewayRanges
-		for _, iprange := range ranges { // go through each cidr for egress gateway
-			_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
-			if err != nil {
-				logger.Log(1, "could not parse gateway IP range. Not adding ", iprange)
-				continue // if can't parse CIDR
-			}
-			nodeEndpointArr := strings.Split(peer.Endpoint, ":") // getting the public ip of node
-			if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain endpoint of node
-				logger.Log(2, "egress IP range of ", iprange, " overlaps with ", node.Endpoint, ", omitting")
-				continue // skip adding egress range if overlaps with node's ip
-			}
-			// TODO: Could put in a lot of great logic to avoid conflicts / bad routes
-			if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
-				logger.Log(2, "egress IP range of ", iprange, " overlaps with ", node.LocalAddress, ", omitting")
-				continue // skip adding egress range if overlaps with node's local ip
-			}
-			if err != nil {
-				logger.Log(1, "error encountered when setting egress range", err.Error())
-			} else {
-				allowedips = append(allowedips, *ipnet)
-			}
-		}
+		egressIPs := getEgressIPs(node, peer)
+		allowedips = append(allowedips, egressIPs...)
 	}
+
 	// handle ingress gateway peers
 	if peer.IsIngressGateway == "yes" {
 		extPeers, err := getExtPeers(peer)
@@ -335,6 +315,15 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 				}
 				allowedips = append(allowedips, relayAddr)
 			}
+			relayedNode, err := findNode(ip)
+			if err != nil {
+				logger.Log(1, "unable to find node for relayed address", ip, err.Error())
+				continue
+			}
+			if relayedNode.IsEgressGateway == "yes" {
+				extAllowedIPs := getEgressIPs(node, relayedNode)
+				allowedips = append(allowedips, extAllowedIPs...)
+			}
 		}
 	}
 	return allowedips
@@ -423,7 +412,6 @@ func GetPeerUpdateForRelayedNode(node *models.Node, udppeers map[string]string)
 				allowedips = append(allowedips[:i], allowedips[i+1:]...)
 			}
 		}
-
 	}
 
 	pubkey, err := wgtypes.ParseKey(relay.PublicKey)
@@ -477,3 +465,30 @@ func GetPeerUpdateForRelayedNode(node *models.Node, udppeers map[string]string)
 	peerUpdate.DNS = getPeerDNS(node.Network)
 	return peerUpdate, nil
 }
+
+func getEgressIPs(node, peer *models.Node) []net.IPNet {
+	allowedips := []net.IPNet{}
+	for _, iprange := range peer.EgressGatewayRanges { // go through each cidr for egress gateway
+		_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
+		if err != nil {
+			logger.Log(1, "could not parse gateway IP range. Not adding ", iprange)
+			continue // if can't parse CIDR
+		}
+		nodeEndpointArr := strings.Split(peer.Endpoint, ":") // getting the public ip of node
+		if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain endpoint of node
+			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", node.Endpoint, ", omitting")
+			continue // skip adding egress range if overlaps with node's ip
+		}
+		// TODO: Could put in a lot of great logic to avoid conflicts / bad routes
+		if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
+			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", node.LocalAddress, ", omitting")
+			continue // skip adding egress range if overlaps with node's local ip
+		}
+		if err != nil {
+			logger.Log(1, "error encountered when setting egress range", err.Error())
+		} else {
+			allowedips = append(allowedips, *ipnet)
+		}
+	}
+	return allowedips
+}