Browse Source

cleaning server side peer logic

afeiszli 3 years ago
parent
commit
3e130fe9f8
3 changed files with 95 additions and 20 deletions
  1. 7 0
      logic/nodes.go
  2. 84 11
      logic/peers.go
  3. 4 9
      mq/mq.go

+ 7 - 0
logic/nodes.go

@@ -108,6 +108,13 @@ func GetPeers(node *models.Node) ([]models.Node, error) {
 	return peers, nil
 	return peers, nil
 }
 }
 
 
+// SetIfLeader - gets the peers of a given server node
+func SetPeersIfLeader(node *models.Node) {
+	if IsLeader(node) {
+		setNetworkServerPeers(node)
+	}
+}
+
 // IsLeader - determines if a given server node is a leader
 // IsLeader - determines if a given server node is a leader
 func IsLeader(node *models.Node) bool {
 func IsLeader(node *models.Node) bool {
 	nodes, err := GetSortedNetworkServerNodes(node.Network)
 	nodes, err := GetSortedNetworkServerNodes(node.Network)

+ 84 - 11
logic/peers.go

@@ -7,6 +7,7 @@ import (
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -21,6 +22,18 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 	if err != nil {
 	if err != nil {
 		return models.PeerUpdate{}, err
 		return models.PeerUpdate{}, err
 	}
 	}
+	// begin translating netclient logic
+	/*
+
+
+		Go through netclient code and put below
+
+
+
+	*/
+	// #1 Set Keepalive values: set_keepalive
+	// #2 Set local address: set_local - could be a LOT BETTER and fix some bugs with additional logic
+	// #3 Set allowedips: set_allowedips
 	for _, peer := range currentPeers {
 	for _, peer := range currentPeers {
 		if peer.ID == node.ID {
 		if peer.ID == node.ID {
 			//skip yourself
 			//skip yourself
@@ -32,6 +45,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 		}
 		}
 		if node.Endpoint == peer.Endpoint {
 		if node.Endpoint == peer.Endpoint {
 			//peer is on same network
 			//peer is on same network
+			// set_local
 			if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
 			if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
 				peer.Endpoint = peer.LocalAddress
 				peer.Endpoint = peer.LocalAddress
 			} else {
 			} else {
@@ -43,9 +57,11 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 		if err != nil {
 		if err != nil {
 			return models.PeerUpdate{}, err
 			return models.PeerUpdate{}, err
 		}
 		}
+		// set_allowedips
 		allowedips := GetAllowedIPs(node, &peer)
 		allowedips := GetAllowedIPs(node, &peer)
 		var keepalive time.Duration
 		var keepalive time.Duration
 		if node.PersistentKeepalive != 0 {
 		if node.PersistentKeepalive != 0 {
+			// set_keepalive
 			keepalive, _ = time.ParseDuration(strconv.FormatInt(int64(node.PersistentKeepalive), 10) + "s")
 			keepalive, _ = time.ParseDuration(strconv.FormatInt(int64(node.PersistentKeepalive), 10) + "s")
 		}
 		}
 		var peerData = wgtypes.PeerConfig{
 		var peerData = wgtypes.PeerConfig{
@@ -60,16 +76,73 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 			serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{ID: peer.ID, IsLeader: IsLeader(&peer), Address: peer.Address})
 			serverNodeAddresses = append(serverNodeAddresses, models.ServerAddr{ID: peer.ID, IsLeader: IsLeader(&peer), Address: peer.Address})
 		}
 		}
 	}
 	}
+	if node.IsIngressGateway == "yes" {
+		extPeers, err := getExtPeers(node)
+		if err == nil {
+			peers = append(peers, extPeers...)
+		} else {
+			log.Println("ERROR RETRIEVING EXTERNAL PEERS", err)
+		}
+	}
 	peerUpdate.Network = node.Network
 	peerUpdate.Network = node.Network
 	peerUpdate.Peers = peers
 	peerUpdate.Peers = peers
 	peerUpdate.ServerAddrs = serverNodeAddresses
 	peerUpdate.ServerAddrs = serverNodeAddresses
+	/*
+
+
+		End translation of netclient code
+
+
+	*/
 	return peerUpdate, nil
 	return peerUpdate, nil
 }
 }
 
 
+func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, error) {
+	var peers []wgtypes.PeerConfig
+	extPeers, err := GetExtPeersList(node)
+	if err != nil {
+		return peers, err
+	}
+	for _, extPeer := range extPeers {
+		pubkey, err := wgtypes.ParseKey(extPeer.PublicKey)
+		if err != nil {
+			logger.Log(1, "error parsing ext pub key:", err.Error())
+			continue
+		}
+
+		if node.PublicKey == extPeer.PublicKey {
+			continue
+		}
+
+		var peer wgtypes.PeerConfig
+		var peeraddr = net.IPNet{
+			IP:   net.ParseIP(extPeer.Address),
+			Mask: net.CIDRMask(32, 32),
+		}
+		var allowedips []net.IPNet
+		allowedips = append(allowedips, peeraddr)
+
+		if extPeer.Address6 != "" {
+			var addr6 = net.IPNet{
+				IP:   net.ParseIP(extPeer.Address6),
+				Mask: net.CIDRMask(128, 128),
+			}
+			allowedips = append(allowedips, addr6)
+		}
+		peer = wgtypes.PeerConfig{
+			PublicKey:         pubkey,
+			ReplaceAllowedIPs: true,
+			AllowedIPs:        allowedips,
+		}
+		peers = append(peers, peer)
+	}
+	return peers, nil
+
+}
+
 // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
 // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
 func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 	var allowedips []net.IPNet
 	var allowedips []net.IPNet
-	var gateways []string
 	var peeraddr = net.IPNet{
 	var peeraddr = net.IPNet{
 		IP:   net.ParseIP(peer.Address),
 		IP:   net.ParseIP(peer.Address),
 		Mask: net.CIDRMask(32, 32),
 		Mask: net.CIDRMask(32, 32),
@@ -77,13 +150,13 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 	dualstack := false
 	dualstack := false
 	allowedips = append(allowedips, peeraddr)
 	allowedips = append(allowedips, peeraddr)
 	// handle manually set peers
 	// handle manually set peers
-	for _, allowedIp := range node.AllowedIPs {
+	for _, allowedIp := range peer.AllowedIPs {
 		if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
 		if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
 			nodeEndpointArr := strings.Split(node.Endpoint, ":")
 			nodeEndpointArr := strings.Split(node.Endpoint, ":")
-			if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists..
+			if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != peer.Address { // don't need to add an allowed ip that already exists..
 				allowedips = append(allowedips, *ipnet)
 				allowedips = append(allowedips, *ipnet)
 			}
 			}
-		} else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address {
+		} else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != peer.Address {
 			ipnet := net.IPNet{
 			ipnet := net.IPNet{
 				IP:   net.ParseIP(allowedIp),
 				IP:   net.ParseIP(allowedIp),
 				Mask: net.CIDRMask(32, 32),
 				Mask: net.CIDRMask(32, 32),
@@ -92,25 +165,25 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 		}
 		}
 	}
 	}
 	// handle egress gateway peers
 	// handle egress gateway peers
-	if node.IsEgressGateway == "yes" {
+	if peer.IsEgressGateway == "yes" {
 		//hasGateway = true
 		//hasGateway = true
-		ranges := node.EgressGatewayRanges
+		ranges := peer.EgressGatewayRanges
 		for _, iprange := range ranges { // go through each cidr for egress gateway
 		for _, iprange := range ranges { // go through each cidr for egress gateway
 			_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
 			_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
 			if err != nil {
 			if err != nil {
 				ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
 				ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
 				continue // if can't parse CIDR
 				continue // if can't parse CIDR
 			}
 			}
-			nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node
-			if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node
+			nodeEndpointArr := strings.Split(peer.Endpoint, ":") // getting the public ip of node
+			if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain endpoint of node
 				ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
 				ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
 				continue // skip adding egress range if overlaps with node's ip
 				continue // skip adding egress range if overlaps with node's ip
 			}
 			}
+			// TODO: Could put in a lot of great logic to avoid conflicts / bad routes
 			if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
 			if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
 				ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2)
 				ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2)
 				continue // skip adding egress range if overlaps with node's local ip
 				continue // skip adding egress range if overlaps with node's local ip
 			}
 			}
-			gateways = append(gateways, iprange)
 			if err != nil {
 			if err != nil {
 				log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
 				log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
 			} else {
 			} else {
@@ -118,9 +191,9 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 			}
 			}
 		}
 		}
 	}
 	}
-	if node.Address6 != "" && dualstack {
+	if peer.Address6 != "" && dualstack {
 		var addr6 = net.IPNet{
 		var addr6 = net.IPNet{
-			IP:   net.ParseIP(node.Address6),
+			IP:   net.ParseIP(peer.Address6),
 			Mask: net.CIDRMask(128, 128),
 			Mask: net.CIDRMask(128, 128),
 		}
 		}
 		allowedips = append(allowedips, addr6)
 		allowedips = append(allowedips, addr6)

+ 4 - 9
mq/mq.go

@@ -101,21 +101,16 @@ var UpdateNode mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message)
 
 
 // PublishPeerUpdate --- deterines and publishes a peer update to all the peers of a node
 // PublishPeerUpdate --- deterines and publishes a peer update to all the peers of a node
 func PublishPeerUpdate(newNode *models.Node) error {
 func PublishPeerUpdate(newNode *models.Node) error {
-	if !servercfg.IsMessageQueueBackend() {
-		return nil
-	}
+	// shouldn't need this becaus of runServerPeerUpdate, but test to make sure peers are getting updated
+	// if newNode.IsServer == "yes" {
+	// 	logic.SetPeersIfLeader(newNode)
+	// }
 	networkNodes, err := logic.GetNetworkNodes(newNode.Network)
 	networkNodes, err := logic.GetNetworkNodes(newNode.Network)
 	if err != nil {
 	if err != nil {
 		logger.Log(1, "err getting Network Nodes", err.Error())
 		logger.Log(1, "err getting Network Nodes", err.Error())
 		return err
 		return err
 	}
 	}
 	for _, node := range networkNodes {
 	for _, node := range networkNodes {
-		if node.IsServer == "yes" {
-			if err := logic.ServerUpdate(&node, true); err != nil {
-				logger.Log(1, "failed server peer update on server", node.ID)
-			}
-			continue
-		}
 		peerUpdate, err := logic.GetPeerUpdate(&node)
 		peerUpdate, err := logic.GetPeerUpdate(&node)
 		if err != nil {
 		if err != nil {
 			logger.Log(1, "error getting peer update for node", node.ID, err.Error())
 			logger.Log(1, "error getting peer update for node", node.ID, err.Error())