Procházet zdrojové kódy

NET-1252: Restrict inetGws, Relays from getting failedOver (#2937)

* add additional checks to avoid failovers

* add failover defence check on signal handler

* only add check for victim node

* avoid failover reset on pull

* add relayed for failoverme

* misc changes for failover

* remove resetfailoverpeers for InetNode

* add egress route back to allowedip list if relayed is egressGW

* add extclient back to allowedip list if peer is ingressGW

* reset failover on pull

---------

Co-authored-by: Max Ma <[email protected]>
Abhishek K před 1 rokem
rodič
revize
7ff30599ed

+ 1 - 1
logic/extpeers.go

@@ -376,7 +376,7 @@ func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.
 	return newClient, nil
 }
 
-func getExtPeers(node, peer *models.Node) ([]wgtypes.PeerConfig, []models.IDandAddr, []models.EgressNetworkRoutes, error) {
+func GetExtPeers(node, peer *models.Node) ([]wgtypes.PeerConfig, []models.IDandAddr, []models.EgressNetworkRoutes, error) {
 	var peers []wgtypes.PeerConfig
 	var idsAndAddr []models.IDandAddr
 	var egressRoutes []models.EgressNetworkRoutes

+ 2 - 2
logic/peers.go

@@ -283,7 +283,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		var extPeerIDAndAddrs []models.IDandAddr
 		var egressRoutes []models.EgressNetworkRoutes
 		if node.IsIngressGateway {
-			extPeers, extPeerIDAndAddrs, egressRoutes, err = getExtPeers(&node, &node)
+			extPeers, extPeerIDAndAddrs, egressRoutes, err = GetExtPeers(&node, &node)
 			if err == nil {
 				hostPeerUpdate.EgressRoutes = append(hostPeerUpdate.EgressRoutes, egressRoutes...)
 				hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, extPeers...)
@@ -415,7 +415,7 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
 
 	// handle ingress gateway peers
 	if peer.IsIngressGateway {
-		extPeers, _, _, err := getExtPeers(peer, node)
+		extPeers, _, _, err := GetExtPeers(peer, node)
 		if err != nil {
 			logger.Log(2, "could not retrieve ext peers for ", peer.ID.String(), err.Error())
 		}

+ 1 - 14
mq/handlers.go

@@ -201,20 +201,7 @@ func signalPeer(signal models.Signal) {
 		slog.Error("failed to signal, peer host not found", "error", err)
 		return
 	}
-	peerNode, err := logic.GetNodeByID(signal.ToNodeID)
-	if err != nil {
-		slog.Error("failed to signal, node not found", "error", err)
-		return
-	}
-	node, err := logic.GetNodeByID(signal.FromNodeID)
-	if err != nil {
-		slog.Error("failed to signal, peer node not found", "error", err)
-		return
-	}
-	if peerNode.IsIngressGateway || node.IsIngressGateway || peerNode.IsInternetGateway || node.IsInternetGateway {
-		signal.Action = ""
-		return
-	}
+
 	err = HostUpdate(&models.HostUpdate{
 		Action: models.SignalHost,
 		Host:   *peerHost,

+ 16 - 4
pro/controllers/failover.go

@@ -159,12 +159,24 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("peer not found"), "badrequest"))
 		return
 	}
-	if node.IsRelayed || node.IsFailOver {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node is relayed or acting as failover"), "badrequest"))
+	if node.IsFailOver {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node is acting as failover"), "badrequest"))
 		return
 	}
-	if peerNode.IsRelayed || peerNode.IsFailOver {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("peer node is relayed or acting as failover"), "badrequest"))
+	if node.IsRelayed && node.RelayedBy == peerNode.ID.String() {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"))
+		return
+	}
+	if node.IsRelay && peerNode.RelayedBy == node.ID.String() {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"))
+		return
+	}
+	if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node acting as internet gw for the peer node"), "badrequest"))
+		return
+	}
+	if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node using a internet gw by the peer node"), "badrequest"))
 		return
 	}
 

+ 34 - 4
pro/logic/failover.go

@@ -5,15 +5,14 @@ import (
 	"net"
 
 	"github.com/google/uuid"
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"golang.org/x/exp/slog"
 )
 
 func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
-	if victimNode.IsIngressGateway || peerNode.IsIngressGateway || victimNode.IsInternetGateway || peerNode.IsInternetGateway {
-		return nil
-	}
+
 	if peerNode.FailOverPeers == nil {
 		peerNode.FailOverPeers = make(map[string]struct{})
 	}
@@ -125,7 +124,38 @@ func GetFailOverPeerIps(peer, node *models.Node) []net.IPNet {
 			if failOverpeer.IsEgressGateway {
 				allowedips = append(allowedips, logic.GetEgressIPs(&failOverpeer)...)
 			}
-
+			if failOverpeer.IsRelay {
+				for _, id := range failOverpeer.RelayedNodes {
+					rNode, _ := logic.GetNodeByID(id)
+					if rNode.Address.IP != nil {
+						allowed := net.IPNet{
+							IP:   rNode.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						}
+						allowedips = append(allowedips, allowed)
+					}
+					if rNode.Address6.IP != nil {
+						allowed := net.IPNet{
+							IP:   rNode.Address6.IP,
+							Mask: net.CIDRMask(128, 128),
+						}
+						allowedips = append(allowedips, allowed)
+					}
+					if rNode.IsEgressGateway {
+						allowedips = append(allowedips, logic.GetEgressIPs(&rNode)...)
+					}
+				}
+			}
+			// handle ingress gateway peers
+			if failOverpeer.IsIngressGateway {
+				extPeers, _, _, err := logic.GetExtPeers(&failOverpeer, node)
+				if err != nil {
+					logger.Log(2, "could not retrieve ext peers for ", peer.ID.String(), err.Error())
+				}
+				for _, extPeer := range extPeers {
+					allowedips = append(allowedips, extPeer.AllowedIPs...)
+				}
+			}
 		}
 	}
 	return allowedips

+ 5 - 0
pro/logic/nodes.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net"
 
+	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"golang.org/x/exp/slog"
@@ -29,6 +30,7 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool
 	if inetNode.IsRelayed {
 		return fmt.Errorf("node %s is being relayed", inetHost.Name)
 	}
+
 	for _, clientNodeID := range req.InetNodeClientIDs {
 		clientNode, err := logic.GetNodeByID(clientNodeID)
 		if err != nil {
@@ -53,6 +55,9 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool
 				return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name)
 			}
 		}
+		if clientNode.FailedOverBy != uuid.Nil {
+			ResetFailedOverPeer(&clientNode)
+		}
 
 		if clientNode.IsRelayed {
 			return fmt.Errorf("node %s is being relayed", clientHost.Name)

+ 4 - 0
pro/logic/relays.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net"
 
+	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/acls/nodeacls"
@@ -122,6 +123,9 @@ func ValidateRelay(relay models.RelayRequest, update bool) error {
 		if relayedNode.IsInternetGateway {
 			return errors.New("cannot relay an internet gateway (" + relayedNodeID + ")")
 		}
+		if relayedNode.FailedOverBy != uuid.Nil {
+			ResetFailedOverPeer(&relayedNode)
+		}
 	}
 	return err
 }