Browse Source

refactor failover logic to set per-peer. Uses Ingress Gateway logic instead of Relay logic

afeiszli 2 years ago
parent
commit
29ce2fa57a
6 changed files with 135 additions and 104 deletions
  1. 1 1
      ee/initialize.go
  2. 12 24
      ee/logic/failover.go
  3. 89 58
      logic/peers.go
  4. 6 5
      models/metrics.go
  5. 1 0
      models/node.go
  6. 26 16
      mq/handlers.go

+ 1 - 1
ee/initialize.go

@@ -29,7 +29,7 @@ func InitEE() {
 		// == End License Handling ==
 		// == End License Handling ==
 		AddLicenseHooks()
 		AddLicenseHooks()
 	})
 	})
-	logic.EnterpriseFailoverFunc = eelogic.AutoRelay
+	logic.EnterpriseFailoverFunc = eelogic.SetFailover
 }
 }
 
 
 func setControllerLimits() {
 func setControllerLimits() {

+ 12 - 24
ee/logic/failover.go

@@ -1,18 +1,17 @@
 package logic
 package logic
 
 
 import (
 import (
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 )
 )
 
 
-// AutoRelay - finds a suitable relay candidate and creates a relay
-func AutoRelay(nodeToBeRelayed *models.Node) (updateNodes []models.Node, err error) {
-	newRelayer := determineFailoverCandidate(nodeToBeRelayed)
-	if newRelayer != nil {
-		return changeRelayStatus(newRelayer, nodeToBeRelayed)
+// SetFailover - finds a suitable failover candidate and sets it
+func SetFailover(node *models.Node) error {
+	failoverNode := determineFailoverCandidate(node)
+	if failoverNode != nil {
+		return setFailoverNode(failoverNode, node)
 	}
 	}
-	return
+	return nil
 }
 }
 
 
 // determineFailoverCandidate - returns a list of nodes that
 // determineFailoverCandidate - returns a list of nodes that
@@ -55,23 +54,12 @@ func determineFailoverCandidate(nodeToBeRelayed *models.Node) *models.Node {
 	return fastestCandidate
 	return fastestCandidate
 }
 }
 
 
-// changeRelayStatus - changes nodes to relay
-func changeRelayStatus(relayer, nodeToBeRelayed *models.Node) ([]models.Node, error) {
-	var newRelayRequest models.RelayRequest
-
-	if relayer.IsRelay == "yes" {
-		newRelayRequest.RelayAddrs = relayer.RelayAddrs
-	}
-	newRelayRequest.NodeID = relayer.ID
-	newRelayRequest.NetID = relayer.Network
-	newRelayRequest.RelayAddrs = append(newRelayRequest.RelayAddrs, nodeToBeRelayed.PrimaryAddress())
-
-	updatenodes, _, err := logic.CreateRelay(newRelayRequest)
+// setFailoverNode - changes node's failover node
+func setFailoverNode(failoverNode, node *models.Node) error {
+	node.FailoverNode = failoverNode.ID
+	nodeToUpdate, err := logic.GetNodeByID(node.ID)
 	if err != nil {
 	if err != nil {
-		logger.Log(0, "failed to create relay automatically for node", nodeToBeRelayed.Name, "on network", nodeToBeRelayed.Network)
-		return nil, err
+		return err
 	}
 	}
-	logger.Log(0, "created relay automatically for node", nodeToBeRelayed.Name, "on network", nodeToBeRelayed.Network)
-
-	return updatenodes, nil
+	return logic.UpdateNode(&nodeToUpdate, node)
 }
 }

+ 89 - 58
logic/peers.go

@@ -33,6 +33,13 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 	}
 	}
 	var peerMap = make(models.PeerMap)
 	var peerMap = make(models.PeerMap)
 
 
+	var metrics *models.Metrics
+	if servercfg.Is_EE {
+		metrics, _ = GetMetrics(node.ID)
+	}
+	if metrics.NeedsFailover == nil {
+		metrics.NeedsFailover = make(map[string]string)
+	}
 	// udppeers = the peers parsed from the local interface
 	// udppeers = the peers parsed from the local interface
 	// gives us correct port to reach
 	// gives us correct port to reach
 	udppeers, errN := database.GetPeers(node.Network)
 	udppeers, errN := database.GetPeers(node.Network)
@@ -85,7 +92,9 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 		if isP2S && peer.IsHub != "yes" {
 		if isP2S && peer.IsHub != "yes" {
 			continue
 			continue
 		}
 		}
-
+		if metrics.NeedsFailover[peer.ID] != "" {
+			continue
+		}
 		pubkey, err := wgtypes.ParseKey(peer.PublicKey)
 		pubkey, err := wgtypes.ParseKey(peer.PublicKey)
 		if err != nil {
 		if err != nil {
 			return models.PeerUpdate{}, err
 			return models.PeerUpdate{}, err
@@ -139,7 +148,7 @@ func GetPeerUpdate(node *models.Node) (models.PeerUpdate, error) {
 			}
 			}
 		}
 		}
 		// set_allowedips
 		// set_allowedips
-		allowedips := GetAllowedIPs(node, &peer)
+		allowedips := GetAllowedIPs(node, &peer, metrics)
 		var keepalive time.Duration
 		var keepalive time.Duration
 		if node.PersistentKeepalive != 0 {
 		if node.PersistentKeepalive != 0 {
 			// set_keepalive
 			// set_keepalive
@@ -247,64 +256,10 @@ func getExtPeers(node *models.Node) ([]wgtypes.PeerConfig, []models.IDandAddr, e
 }
 }
 
 
 // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
 // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
-func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
+func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet {
 	var allowedips []net.IPNet
 	var allowedips []net.IPNet
 
 
-	if peer.Address != "" {
-		var peeraddr = net.IPNet{
-			IP:   net.ParseIP(peer.Address),
-			Mask: net.CIDRMask(32, 32),
-		}
-		allowedips = append(allowedips, peeraddr)
-	}
-
-	if peer.Address6 != "" {
-		var addr6 = net.IPNet{
-			IP:   net.ParseIP(peer.Address6),
-			Mask: net.CIDRMask(128, 128),
-		}
-		allowedips = append(allowedips, addr6)
-	}
-	// handle manually set peers
-	for _, allowedIp := range peer.AllowedIPs {
-
-		// parsing as a CIDR first. If valid CIDR, append
-		if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
-			nodeEndpointArr := strings.Split(node.Endpoint, ":")
-			if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != peer.Address { // don't need to add an allowed ip that already exists..
-				allowedips = append(allowedips, *ipnet)
-			}
-
-		} else { // parsing as an IP second. If valid IP, check if ipv4 or ipv6, then append
-			if iplib.Version(net.ParseIP(allowedIp)) == 4 && allowedIp != peer.Address {
-				ipnet := net.IPNet{
-					IP:   net.ParseIP(allowedIp),
-					Mask: net.CIDRMask(32, 32),
-				}
-				allowedips = append(allowedips, ipnet)
-			} else if iplib.Version(net.ParseIP(allowedIp)) == 6 && allowedIp != peer.Address6 {
-				ipnet := net.IPNet{
-					IP:   net.ParseIP(allowedIp),
-					Mask: net.CIDRMask(128, 128),
-				}
-				allowedips = append(allowedips, ipnet)
-			}
-		}
-	}
-	// handle egress gateway peers
-	if peer.IsEgressGateway == "yes" {
-		//hasGateway = true
-		egressIPs := getEgressIPs(node, peer)
-		// remove internet gateway if server
-		if node.IsServer == "yes" {
-			for i := len(egressIPs) - 1; i >= 0; i-- {
-				if egressIPs[i].String() == "0.0.0.0/0" || egressIPs[i].String() == "::/0" {
-					egressIPs = append(egressIPs[:i], egressIPs[i+1:]...)
-				}
-			}
-		}
-		allowedips = append(allowedips, egressIPs...)
-	}
+	allowedips = getNodeAllowedIPs(peer, node)
 
 
 	// handle ingress gateway peers
 	// handle ingress gateway peers
 	if peer.IsIngressGateway == "yes" {
 	if peer.IsIngressGateway == "yes" {
@@ -315,6 +270,21 @@ func GetAllowedIPs(node, peer *models.Node) []net.IPNet {
 		for _, extPeer := range extPeers {
 		for _, extPeer := range extPeers {
 			allowedips = append(allowedips, extPeer.AllowedIPs...)
 			allowedips = append(allowedips, extPeer.AllowedIPs...)
 		}
 		}
+		// if node is a failover node, add allowed ips from nodes it is handling
+		if peer.Failover == "yes" && metrics.NeedsFailover != nil {
+			// travers through nodes that need handling
+			for k, v := range metrics.NeedsFailover {
+				// if FailoverNode is me for this node, add allowedips
+				if v == peer.ID {
+					// get original node so we can traverse the allowed ips
+					nodeToFailover, err := GetNodeByID(k)
+					if err == nil {
+						// get all allowedips and append
+						allowedips = append(allowedips, getNodeAllowedIPs(&nodeToFailover, peer)...)
+					}
+				}
+			}
+		}
 	}
 	}
 	// handle relay gateway peers
 	// handle relay gateway peers
 	if peer.IsRelay == "yes" {
 	if peer.IsRelay == "yes" {
@@ -559,3 +529,64 @@ func getEgressIPs(node, peer *models.Node) []net.IPNet {
 	}
 	}
 	return allowedips
 	return allowedips
 }
 }
+
+func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
+	var allowedips = []net.IPNet{}
+
+	if peer.Address != "" {
+		var peeraddr = net.IPNet{
+			IP:   net.ParseIP(peer.Address),
+			Mask: net.CIDRMask(32, 32),
+		}
+		allowedips = append(allowedips, peeraddr)
+	}
+
+	if peer.Address6 != "" {
+		var addr6 = net.IPNet{
+			IP:   net.ParseIP(peer.Address6),
+			Mask: net.CIDRMask(128, 128),
+		}
+		allowedips = append(allowedips, addr6)
+	}
+	// handle manually set peers
+	for _, allowedIp := range peer.AllowedIPs {
+
+		// parsing as a CIDR first. If valid CIDR, append
+		if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
+			nodeEndpointArr := strings.Split(node.Endpoint, ":")
+			if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != peer.Address { // don't need to add an allowed ip that already exists..
+				allowedips = append(allowedips, *ipnet)
+			}
+
+		} else { // parsing as an IP second. If valid IP, check if ipv4 or ipv6, then append
+			if iplib.Version(net.ParseIP(allowedIp)) == 4 && allowedIp != peer.Address {
+				ipnet := net.IPNet{
+					IP:   net.ParseIP(allowedIp),
+					Mask: net.CIDRMask(32, 32),
+				}
+				allowedips = append(allowedips, ipnet)
+			} else if iplib.Version(net.ParseIP(allowedIp)) == 6 && allowedIp != peer.Address6 {
+				ipnet := net.IPNet{
+					IP:   net.ParseIP(allowedIp),
+					Mask: net.CIDRMask(128, 128),
+				}
+				allowedips = append(allowedips, ipnet)
+			}
+		}
+	}
+	// handle egress gateway peers
+	if peer.IsEgressGateway == "yes" {
+		//hasGateway = true
+		egressIPs := getEgressIPs(node, peer)
+		// remove internet gateway if server
+		if node.IsServer == "yes" {
+			for i := len(egressIPs) - 1; i >= 0; i-- {
+				if egressIPs[i].String() == "0.0.0.0/0" || egressIPs[i].String() == "::/0" {
+					egressIPs = append(egressIPs[:i], egressIPs[i+1:]...)
+				}
+			}
+		}
+		allowedips = append(allowedips, egressIPs...)
+	}
+	return allowedips
+}

+ 6 - 5
models/metrics.go

@@ -4,11 +4,12 @@ import "time"
 
 
 // Metrics - metrics struct
 // Metrics - metrics struct
 type Metrics struct {
 type Metrics struct {
-	Network      string            `json:"network" bson:"network" yaml:"network"`
-	NodeID       string            `json:"node_id" bson:"node_id" yaml:"node_id"`
-	NodeName     string            `json:"node_name" bson:"node_name" yaml:"node_name"`
-	IsServer     string            `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"`
-	Connectivity map[string]Metric `json:"connectivity" bson:"connectivity" yaml:"connectivity"`
+	Network       string            `json:"network" bson:"network" yaml:"network"`
+	NodeID        string            `json:"node_id" bson:"node_id" yaml:"node_id"`
+	NodeName      string            `json:"node_name" bson:"node_name" yaml:"node_name"`
+	IsServer      string            `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"`
+	Connectivity  map[string]Metric `json:"connectivity" bson:"connectivity" yaml:"connectivity"`
+	NeedsFailover map[string]string `json:"needsfailover" bson:"needsfailover" yaml:"needsfailover"`
 }
 }
 
 
 // Metric - holds a metric for data between nodes
 // Metric - holds a metric for data between nodes

+ 1 - 0
models/node.go

@@ -82,6 +82,7 @@ type Node struct {
 	EgressGatewayNatEnabled string               `json:"egressgatewaynatenabled" bson:"egressgatewaynatenabled" yaml:"egressgatewaynatenabled"`
 	EgressGatewayNatEnabled string               `json:"egressgatewaynatenabled" bson:"egressgatewaynatenabled" yaml:"egressgatewaynatenabled"`
 	EgressGatewayRequest    EgressGatewayRequest `json:"egressgatewayrequest" bson:"egressgatewayrequest" yaml:"egressgatewayrequest"`
 	EgressGatewayRequest    EgressGatewayRequest `json:"egressgatewayrequest" bson:"egressgatewayrequest" yaml:"egressgatewayrequest"`
 	RelayAddrs              []string             `json:"relayaddrs" bson:"relayaddrs" yaml:"relayaddrs"`
 	RelayAddrs              []string             `json:"relayaddrs" bson:"relayaddrs" yaml:"relayaddrs"`
+	FailoverNode            string               `json:"failovernode" bson:"failovernode" yaml:"failovernode"`
 	IngressGatewayRange     string               `json:"ingressgatewayrange" bson:"ingressgatewayrange" yaml:"ingressgatewayrange"`
 	IngressGatewayRange     string               `json:"ingressgatewayrange" bson:"ingressgatewayrange" yaml:"ingressgatewayrange"`
 	IngressGatewayRange6    string               `json:"ingressgatewayrange6" bson:"ingressgatewayrange6" yaml:"ingressgatewayrange6"`
 	IngressGatewayRange6    string               `json:"ingressgatewayrange6" bson:"ingressgatewayrange6" yaml:"ingressgatewayrange6"`
 	// IsStatic - refers to if the Endpoint is set manually or dynamically
 	// IsStatic - refers to if the Endpoint is set manually or dynamically

+ 26 - 16
mq/handlers.go

@@ -136,23 +136,15 @@ func UpdateMetrics(client mqtt.Client, msg mqtt.Message) {
 			}
 			}
 
 
 			if newMetrics.Connectivity != nil {
 			if newMetrics.Connectivity != nil {
-				hasDisconnection := false
-				for k := range newMetrics.Connectivity {
-					if !newMetrics.Connectivity[k].Connected {
-						hasDisconnection = true
+				err := logic.EnterpriseFailoverFunc.(func(*models.Node) error)(&currentNode)
+				if err != nil {
+					logger.Log(0, "could failed to failover for node", currentNode.Name, "on network", currentNode.Network, "-", err.Error())
+				} else {
+					if err := NodeUpdate(&currentNode); err != nil {
+						logger.Log(1, "error publishing node update to node", currentNode.Name, err.Error())
 					}
 					}
-				}
-				if hasDisconnection {
-					_, err := logic.EnterpriseFailoverFunc.(func(*models.Node) ([]models.Node, error))(&currentNode)
-					if err != nil {
-						logger.Log(0, "could failed to failover for node", currentNode.Name, "on network", currentNode.Network, "-", err.Error())
-					} else {
-						if err := NodeUpdate(&currentNode); err != nil {
-							logger.Log(1, "error publishing node update to node", currentNode.Name, err.Error())
-						}
-						if err := PublishPeerUpdate(&currentNode, true); err != nil {
-							logger.Log(1, "error publishing peer update after auto relay for node", currentNode.Name, err.Error())
-						}
+					if err := PublishPeerUpdate(&currentNode, true); err != nil {
+						logger.Log(1, "error publishing peer update after auto relay for node", currentNode.Name, err.Error())
 					}
 					}
 				}
 				}
 			}
 			}
@@ -217,11 +209,17 @@ func updateNodePeers(currentNode *models.Node) {
 }
 }
 
 
 func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) {
 func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) {
+	if newMetrics.NeedsFailover == nil {
+		newMetrics.NeedsFailover = make(map[string]string)
+	}
 	oldMetrics, err := logic.GetMetrics(currentNode.ID)
 	oldMetrics, err := logic.GetMetrics(currentNode.ID)
 	if err != nil {
 	if err != nil {
 		logger.Log(1, "error finding old metrics for node", currentNode.ID, currentNode.Name)
 		logger.Log(1, "error finding old metrics for node", currentNode.ID, currentNode.Name)
 		return
 		return
 	}
 	}
+	if oldMetrics.NeedsFailover == nil {
+		oldMetrics.NeedsFailover = make(map[string]string)
+	}
 
 
 	var attachedClients []models.ExtClient
 	var attachedClients []models.ExtClient
 	if currentNode.IsIngressGateway == "yes" {
 	if currentNode.IsIngressGateway == "yes" {
@@ -254,6 +252,18 @@ func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) {
 		newMetrics.Connectivity[k] = currMetric
 		newMetrics.Connectivity[k] = currMetric
 	}
 	}
 
 
+	// add nodes that need failover
+	nodes, err := logic.GetNetworkNodes(currentNode.Network)
+	if err != nil {
+		logger.Log(0, "failed to retrieve nodes while updating metrics")
+		return
+	}
+	for _, node := range nodes {
+		if !newMetrics.Connectivity[node.ID].Connected && node.Connected == "yes" {
+			newMetrics.NeedsFailover[node.ID] = node.FailoverNode
+		}
+	}
+
 	for k := range oldMetrics.Connectivity { // cleanup any left over data, self healing
 	for k := range oldMetrics.Connectivity { // cleanup any left over data, self healing
 		delete(newMetrics.Connectivity, k)
 		delete(newMetrics.Connectivity, k)
 	}
 	}