Browse Source

Merge pull request #1212 from gravitl/bugfix_v0.14.3_relayed_fix

adding server/relay check for endpoint
Matthew R Kasun 3 years ago
parent
commit
ac4ba2e868
5 changed files with 117 additions and 94 deletions
  1. 73 35
      logic/peers.go
  2. 1 30
      logic/relay.go
  3. 12 27
      netclient/local/routes.go
  4. 26 0
      netclient/ncutils/netclientutils.go
  5. 5 2
      netclient/wireguard/common.go

+ 73 - 35
logic/peers.go

@@ -24,13 +24,6 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 	var peerUpdate models.PeerUpdate
 	var peers []wgtypes.PeerConfig
 	var serverNodeAddresses = []models.ServerAddr{}
-	currentPeers, err := GetNetworkNodes(node.Network)
-	if err != nil {
-		return models.PeerUpdate{}, err
-	}
-	if node.IsRelayed == "yes" {
-		return GetPeerUpdateForRelayedNode(node)
-	}
 
 	// udppeers = the peers parsed from the local interface
 	// gives us correct port to reach
@@ -39,18 +32,34 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 		logger.Log(2, errN.Error())
 	}
 
+	currentPeers, err := GetNetworkNodes(node.Network)
+	if err != nil {
+		return models.PeerUpdate{}, err
+	}
+
+	if node.IsRelayed == "yes" {
+		return GetPeerUpdateForRelayedNode(node, udppeers)
+	}
+
 	// #1 Set Keepalive values: set_keepalive
 	// #2 Set local address: set_local - could be a LOT BETTER and fix some bugs with additional logic
 	// #3 Set allowedips: set_allowedips
 	for _, peer := range currentPeers {
+
+		// if the node is not a server, set the endpoint
+		var setEndpoint = !(node.IsServer == "yes")
+
 		if peer.ID == node.ID {
 			//skip yourself
 			continue
 		}
 		if peer.IsRelayed == "yes" {
 			if !(node.IsRelay == "yes" && ncutils.StringSliceContains(node.RelayAddrs, peer.PrimaryAddress())) {
-				//skip -- willl be added to relay
+				//skip -- will be added to relay
 				continue
+			} else if node.IsRelay == "yes" && ncutils.StringSliceContains(node.RelayAddrs, peer.PrimaryAddress()) {
+				// dont set peer endpoint if it's relayed by node
+				setEndpoint = false
 			}
 		}
 		if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID), nodeacls.NodeID(peer.ID)) {
@@ -73,33 +82,41 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 				continue
 			}
 		}
+
+		// set address if setEndpoint is true
+		// otherwise, will get inserted as empty value
+		var address *net.UDPAddr
+
 		// Sets ListenPort to UDP Hole Punching Port assuming:
 		// - UDP Hole Punching is enabled
 		// - udppeers retrieval did not return an error
 		// - the endpoint is valid
-		var setUDPPort = false
-		if peer.UDPHolePunch == "yes" && errN == nil && CheckEndpoint(udppeers[peer.PublicKey]) {
-			endpointstring := udppeers[peer.PublicKey]
-			endpointarr := strings.Split(endpointstring, ":")
-			if len(endpointarr) == 2 {
-				port, err := strconv.Atoi(endpointarr[1])
-				if err == nil {
-					setUDPPort = true
-					peer.ListenPort = int32(port)
+		if setEndpoint {
+
+			var setUDPPort = false
+			if peer.UDPHolePunch == "yes" && errN == nil && CheckEndpoint(udppeers[peer.PublicKey]) {
+				endpointstring := udppeers[peer.PublicKey]
+				endpointarr := strings.Split(endpointstring, ":")
+				if len(endpointarr) == 2 {
+					port, err := strconv.Atoi(endpointarr[1])
+					if err == nil {
+						setUDPPort = true
+						peer.ListenPort = int32(port)
+					}
 				}
 			}
-		}
-		// if udp hole punching is on, but udp hole punching did not set it, use the LocalListenPort instead
-		// or, if port is for some reason zero use the LocalListenPort
-		// but only do this if LocalListenPort is not zero
-		if ((peer.UDPHolePunch == "yes" && !setUDPPort) || peer.ListenPort == 0) && peer.LocalListenPort != 0 {
-			peer.ListenPort = peer.LocalListenPort
-		}
+			// if udp hole punching is on, but udp hole punching did not set it, use the LocalListenPort instead
+			// or, if port is for some reason zero use the LocalListenPort
+			// but only do this if LocalListenPort is not zero
+			if ((peer.UDPHolePunch == "yes" && !setUDPPort) || peer.ListenPort == 0) && peer.LocalListenPort != 0 {
+				peer.ListenPort = peer.LocalListenPort
+			}
 
-		endpoint := peer.Endpoint + ":" + strconv.FormatInt(int64(peer.ListenPort), 10)
-		address, err := net.ResolveUDPAddr("udp", endpoint)
-		if err != nil {
-			return models.PeerUpdate{}, err
+			endpoint := peer.Endpoint + ":" + strconv.FormatInt(int64(peer.ListenPort), 10)
+			address, err = net.ResolveUDPAddr("udp", endpoint)
+			if err != nil {
+				return models.PeerUpdate{}, err
+			}
 		}
 		// set_allowedips
 		allowedips := GetAllowedIPs(node, &peer)
@@ -115,6 +132,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 			AllowedIPs:                  allowedips,
 			PersistentKeepaliveInterval: &keepalive,
 		}
+
 		peers = append(peers, peerData)
 		if peer.IsServer == "yes" {
 			serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{IsLeader: IsLeader(&peer), Address: peer.Address})
@@ -326,7 +344,7 @@ func getPeerDNS(network string) string {
 
 // GetPeerUpdateForRelayedNode - calculates peer update for a relayed node by getting the relay
 // copying the relay node's allowed ips and making appropriate substitutions
-func GetPeerUpdateForRelayedNode(node *models.Node) (models.PeerUpdate, error) {
+func GetPeerUpdateForRelayedNode(node *models.Node, udppeers map[string]string) (models.PeerUpdate, error) {
 	var peerUpdate models.PeerUpdate
 	var peers []wgtypes.PeerConfig
 	var serverNodeAddresses = []models.ServerAddr{}
@@ -336,6 +354,7 @@ func GetPeerUpdateForRelayedNode(node *models.Node) (models.PeerUpdate, error) {
 	if relay == nil {
 		return models.PeerUpdate{}, errors.New("not found")
 	}
+
 	//add relay to lists of allowed ip
 	if relay.Address != "" {
 		relayIP := net.IPNet{
@@ -361,14 +380,14 @@ func GetPeerUpdateForRelayedNode(node *models.Node) (models.PeerUpdate, error) {
 		allowedips = append(allowedips, peer.AllowedIPs...)
 	}
 	//delete any ips not permitted by acl
-	for i, ip := range allowedips {
-		target, err := findNode(ip.IP.String())
+	for i := len(allowedips) - 1; i >= 0; i-- {
+		target, err := findNode(allowedips[i].IP.String())
 		if err != nil {
-			logger.Log(0, "failed to find node for ip", ip.IP.String(), err.Error())
+			logger.Log(0, "failed to find node for ip", allowedips[i].IP.String(), err.Error())
 			continue
 		}
 		if target == nil {
-			logger.Log(0, "failed to find node for ip", ip.IP.String())
+			logger.Log(0, "failed to find node for ip", allowedips[i].IP.String())
 			continue
 		}
 		if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID), nodeacls.NodeID(target.ID)) {
@@ -377,8 +396,8 @@ func GetPeerUpdateForRelayedNode(node *models.Node) (models.PeerUpdate, error) {
 		}
 	}
 	//delete self from allowed ips
-	for i, ip := range allowedips {
-		if ip.IP.String() == node.Address || ip.IP.String() == node.Address6 {
+	for i := len(allowedips) - 1; i >= 0; i-- {
+		if allowedips[i].IP.String() == node.Address || allowedips[i].IP.String() == node.Address6 {
 			allowedips = append(allowedips[:i], allowedips[i+1:]...)
 		}
 	}
@@ -387,6 +406,25 @@ func GetPeerUpdateForRelayedNode(node *models.Node) (models.PeerUpdate, error) {
 	if err != nil {
 		return models.PeerUpdate{}, err
 	}
+	var setUDPPort = false
+	if relay.UDPHolePunch == "yes" && CheckEndpoint(udppeers[relay.PublicKey]) {
+		endpointstring := udppeers[relay.PublicKey]
+		endpointarr := strings.Split(endpointstring, ":")
+		if len(endpointarr) == 2 {
+			port, err := strconv.Atoi(endpointarr[1])
+			if err == nil {
+				setUDPPort = true
+				relay.ListenPort = int32(port)
+			}
+		}
+	}
+	// if udp hole punching is on, but udp hole punching did not set it, use the LocalListenPort instead
+	// or, if port is for some reason zero use the LocalListenPort
+	// but only do this if LocalListenPort is not zero
+	if ((relay.UDPHolePunch == "yes" && !setUDPPort) || relay.ListenPort == 0) && relay.LocalListenPort != 0 {
+		relay.ListenPort = relay.LocalListenPort
+	}
+
 	endpoint := relay.Endpoint + ":" + strconv.FormatInt(int64(relay.ListenPort), 10)
 	address, err := net.ResolveUDPAddr("udp", endpoint)
 	if err != nil {

+ 1 - 30
logic/relay.go

@@ -54,19 +54,13 @@ func SetRelayedNodes(setRelayed bool, networkName string, addrs []string) ([]mod
 	if err != nil {
 		return returnnodes, err
 	}
-	network, err := GetNetworkSettings(networkName)
-	if err != nil {
-		return returnnodes, err
-	}
 	for _, node := range networkNodes {
 		if node.IsServer != "yes" {
 			for _, addr := range addrs {
 				if addr == node.Address || addr == node.Address6 {
 					if setRelayed {
-						node.UDPHolePunch = "no"
 						node.IsRelayed = "yes"
 					} else {
-						node.UDPHolePunch = network.DefaultUDPHolePunch
 						node.IsRelayed = "no"
 					}
 					data, err := json.Marshal(&node)
@@ -82,29 +76,6 @@ func SetRelayedNodes(setRelayed bool, networkName string, addrs []string) ([]mod
 	return returnnodes, nil
 }
 
-// SetNodeIsRelayed - Sets IsRelayed to on or off for relay
-func SetNodeIsRelayed(yesOrno string, id string) (models.Node, error) {
-	node, err := GetNodeByID(id)
-	if err != nil {
-		return node, err
-	}
-	network, err := GetNetworkByNode(&node)
-	if err != nil {
-		return node, err
-	}
-	node.IsRelayed = yesOrno
-	if yesOrno == "yes" {
-		node.UDPHolePunch = "no"
-	} else {
-		node.UDPHolePunch = network.DefaultUDPHolePunch
-	}
-	data, err := json.Marshal(&node)
-	if err != nil {
-		return node, err
-	}
-	return node, database.Insert(node.ID, string(data), database.NODES_TABLE_NAME)
-}
-
 // ValidateRelay - checks if relay is valid
 func ValidateRelay(relay models.RelayRequest) error {
 	var err error
@@ -138,7 +109,7 @@ func DeleteRelay(network, nodeid string) ([]models.Node, models.Node, error) {
 	if err != nil {
 		return returnnodes, models.Node{}, err
 	}
-	_, err = SetRelayedNodes(false, node.Network, node.RelayAddrs)
+	returnnodes, err = SetRelayedNodes(false, node.Network, node.RelayAddrs)
 	if err != nil {
 		return returnnodes, node, err
 	}

+ 12 - 27
netclient/local/routes.go

@@ -11,41 +11,26 @@ import (
 // TODO handle ipv6 in future
 
 // SetPeerRoutes - sets/removes ip routes for each peer on a network
-func SetPeerRoutes(iface string, oldPeers map[string][]net.IPNet, newPeers []wgtypes.PeerConfig) {
+func SetPeerRoutes(iface string, oldPeers map[string]bool, newPeers []wgtypes.PeerConfig) {
 	// traverse through all recieved peers
 	for _, peer := range newPeers {
-		// if pubkey found in existing peers, check against existing peer
-		currPeerAllowedIPs := oldPeers[peer.PublicKey.String()]
-		if currPeerAllowedIPs != nil {
-			// traverse IPs, check to see if old peer contains each IP
-			for _, allowedIP := range peer.AllowedIPs { // compare new ones (if any) to old ones
-				if !ncutils.IPNetSliceContains(currPeerAllowedIPs, allowedIP) {
-					if err := setRoute(iface, &allowedIP, allowedIP.IP.String()); err != nil {
-						logger.Log(1, err.Error())
-					}
-				}
-			}
-			for _, allowedIP := range currPeerAllowedIPs { // compare old ones (if any) to new ones
-				if !ncutils.IPNetSliceContains(peer.AllowedIPs, allowedIP) {
-					if err := deleteRoute(iface, &allowedIP, allowedIP.IP.String()); err != nil {
-						logger.Log(1, err.Error())
-					}
-				}
-			}
-			delete(oldPeers, peer.PublicKey.String()) // remove peer as it was found and processed
-		} else {
-			for _, allowedIP := range peer.AllowedIPs { // add all routes as peer doesn't exist
-				if err := setRoute(iface, &allowedIP, allowedIP.String()); err != nil {
+		for _, allowedIP := range peer.AllowedIPs {
+			if !oldPeers[allowedIP.String()] {
+				if err := setRoute(iface, &allowedIP, allowedIP.IP.String()); err != nil {
 					logger.Log(1, err.Error())
 				}
+			} else {
+				delete(oldPeers, allowedIP.String())
 			}
 		}
 	}
-
 	// traverse through all remaining existing peers
-	for _, allowedIPs := range oldPeers {
-		for _, allowedIP := range allowedIPs {
-			deleteRoute(iface, &allowedIP, allowedIP.IP.String())
+	for i, _ := range oldPeers {
+		ip, err := ncutils.GetIPNetFromString(i)
+		if err != nil {
+			logger.Log(1, err.Error())
+		} else {
+			deleteRoute(iface, &ip, ip.IP.String())
 		}
 	}
 }

+ 26 - 0
netclient/ncutils/netclientutils.go

@@ -19,6 +19,7 @@ import (
 	"strings"
 	"time"
 
+	"github.com/c-robinson/iplib"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -593,3 +594,28 @@ func MakeRandomString(n int) string {
 	}
 	return string(result)
 }
+
+func GetIPNetFromString(ip string) (net.IPNet, error) {
+	var ipnet *net.IPNet
+	var err error
+	// parsing as a CIDR first. If valid CIDR, append
+	if _, cidr, err := net.ParseCIDR(ip); err == nil {
+		ipnet = cidr
+	} else { // parsing as an IP second. If valid IP, check if ipv4 or ipv6, then append
+		if iplib.Version(net.ParseIP(ip)) == 4 {
+			ipnet = &net.IPNet{
+				IP:   net.ParseIP(ip),
+				Mask: net.CIDRMask(32, 32),
+			}
+		} else if iplib.Version(net.ParseIP(ip)) == 6 {
+			ipnet = &net.IPNet{
+				IP:   net.ParseIP(ip),
+				Mask: net.CIDRMask(128, 128),
+			}
+		}
+	}
+	if ipnet == nil {
+		err = errors.New(ip + " is not a valid ip or cidr")
+	}
+	return *ipnet, err
+}

+ 5 - 2
netclient/wireguard/common.go

@@ -28,7 +28,8 @@ const (
 func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error {
 	var devicePeers []wgtypes.Peer
 	var keepalive = node.PersistentKeepalive
-	var oldPeerAllowedIps = make(map[string][]net.IPNet, len(peers))
+	var oldPeerAllowedIps = make(map[string]bool, len(peers))
+
 	var err error
 	devicePeers, err = GetDevicePeers(iface)
 	if err != nil {
@@ -106,7 +107,9 @@ func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error
 						log.Println(output, "error removing peer", currentPeer.PublicKey.String())
 					}
 				}
-				oldPeerAllowedIps[currentPeer.PublicKey.String()] = currentPeer.AllowedIPs
+				for _, ip := range currentPeer.AllowedIPs {
+					oldPeerAllowedIps[ip.String()] = true
+				}
 			}
 		}
 	}