Sfoglia il codice sorgente

Merge branch 'NET-1941' of https://github.com/gravitl/netmaker into NET-1911-latest

abhishek9686 7 mesi fa
parent
commit
f0813d102d

+ 2 - 0
Dockerfile

@@ -11,6 +11,8 @@ FROM alpine:3.21.2
 # add a c lib
 # set the working directory
 WORKDIR /root/
+RUN apk update && apk upgrade
+RUN apk add --no-cache sqlite
 RUN mkdir -p /etc/netclient/config
 COPY --from=builder /app/netmaker .
 COPY --from=builder /app/config config

+ 1 - 9
compose/docker-compose.yml

@@ -12,7 +12,7 @@ services:
       - sqldata:/root/data
     environment:
       # config-dependant vars
-      - STUN_SERVERS=stun.${NM_DOMAIN}:3478,stun1.l.google.com:19302,stun2.l.google.com:19302,stun3.l.google.com:19302,stun4.l.google.com:19302
+      - STUN_SERVERS=stun1.l.google.com:19302,stun2.l.google.com:19302,stun3.l.google.com:19302,stun4.l.google.com:19302
       # The domain/host IP indicating the mq broker address
       - BROKER_ENDPOINT=wss://broker.${NM_DOMAIN} # For EMQX broker use `BROKER_ENDPOINT=wss://broker.${NM_DOMAIN}/mqtt`
       # For EMQX broker (uncomment the two lines below)
@@ -39,14 +39,6 @@ services:
     links:
       - "netmaker:api"
     restart: always
-  stun:
-    container_name: stun
-    image: coturn/coturn
-    restart: always
-    ports:
-      - "3478:3478/udp"   # STUN UDP
-    environment:
-     - LISTENING_PORT=3478
 
   caddy:
     image: caddy:2.8.4

+ 34 - 0
controllers/hosts.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"time"
 
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
@@ -48,6 +49,8 @@ func hostHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/fallback/host/{hostid}", Authorize(true, false, "host", http.HandlerFunc(hostUpdateFallback))).
 		Methods(http.MethodPut)
+	r.HandleFunc("/api/v1/host/{hostid}/peer_info", Authorize(true, false, "host", http.HandlerFunc(getHostPeerInfo))).
+		Methods(http.MethodGet)
 	r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))).
 		Methods(http.MethodDelete)
 	r.HandleFunc("/api/v1/auth-register/host", socketHandler)
@@ -943,6 +946,7 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
 					slog.Info("host sync requested", "user", user, "host", host.ID.String())
 				}
 			}(host)
+			time.Sleep(time.Millisecond * 100)
 		}
 	}()
 
@@ -1017,3 +1021,33 @@ func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
 	}
 	logic.ReturnSuccessResponse(w, r, "deleted hosts data on emqx")
 }
+
+// @Summary     Fetches host peerinfo
+// @Router      /api/host/{hostid}/peer_info [get]
+// @Tags        Hosts
+// @Security    oauth
+// @Param       hostid path string true "Host ID"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     500 {object} models.ErrorResponse
+func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
+	hostId := mux.Vars(r)["hostid"]
+	var errorResponse = models.ErrorResponse{}
+
+	host, err := logic.GetHost(hostId)
+	if err != nil {
+		slog.Error("failed to retrieve host", "error", err)
+		errorResponse.Code = http.StatusBadRequest
+		errorResponse.Message = err.Error()
+		logic.ReturnErrorResponse(w, r, errorResponse)
+		return
+	}
+	peerInfo, err := logic.GetHostPeerInfo(host)
+	if err != nil {
+		slog.Error("failed to retrieve host peerinfo", "error", err)
+		errorResponse.Code = http.StatusBadRequest
+		errorResponse.Message = err.Error()
+		logic.ReturnErrorResponse(w, r, errorResponse)
+		return
+	}
+	logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info")
+}

+ 0 - 4
docker/Caddyfile

@@ -32,7 +32,3 @@ broker.{$NM_DOMAIN} {
 		}
 	reverse_proxy @ws mq:8883   # For EMQX websockets use `reverse_proxy @ws mq:8083`
 }
-
-https://stun.{$NM_DOMAIN} {
-	reverse_proxy stun:3478
-}

+ 81 - 5
logic/acls.go

@@ -14,10 +14,19 @@ import (
 )
 
 var (
-	aclCacheMutex = &sync.RWMutex{}
-	aclCacheMap   = make(map[string]models.Acl)
+	aclCacheMutex      = &sync.RWMutex{}
+	aclCacheMap        = make(map[string]models.Acl)
+	aclNetCacheMutex   = &sync.RWMutex{}
+	aclNetworkCacheMap = make(map[models.NetworkID]AclNetInfo)
 )
 
// AclNetInfo - cached per-network view of ACL policies, bucketed by rule
// type, together with the network's default device and user policies.
type AclNetInfo struct {
	// DevicePolices holds the network's device (node) policies.
	// NOTE(review): field name is a typo for "DevicePolicies"; kept as-is
	// because sibling cache helpers in this file reference it.
	DevicePolices       []models.Acl
	UserPolicies        []models.Acl
	DefaultDevicePolicy models.Acl
	DefaultUserPolicy   models.Acl
}
+
 func MigrateAclPolicies() {
 	acls := ListAcls()
 	for _, acl := range acls {
@@ -31,6 +40,34 @@ func MigrateAclPolicies() {
 
 }
 
+func loadNetworkAclsIntoCache() {
+	aclNetCacheMutex.Lock()
+	defer aclNetCacheMutex.Unlock()
+	aclNetworkCacheMap = make(map[models.NetworkID]AclNetInfo)
+	acls := ListAcls()
+	for _, acl := range acls {
+		aclNetInfo := aclNetworkCacheMap[acl.NetworkID]
+		if acl.RuleType == models.DevicePolicy {
+			aclNetInfo.DevicePolices = append(aclNetInfo.DevicePolices, acl)
+		} else {
+			aclNetInfo.UserPolicies = append(aclNetInfo.UserPolicies, acl)
+		}
+		aclNetworkCacheMap[acl.NetworkID] = aclNetInfo
+	}
+	for netID, aclNetInfo := range aclNetworkCacheMap {
+		defaultDevicePolicy, err := GetDefaultPolicy(models.NetworkID(netID), models.DevicePolicy)
+		if err == nil {
+			aclNetInfo.DefaultDevicePolicy = defaultDevicePolicy
+			aclNetworkCacheMap[netID] = aclNetInfo
+		}
+		defaultUserPolicy, err := GetDefaultPolicy(models.NetworkID(netID), models.UserPolicy)
+		if err == nil {
+			aclNetInfo.DefaultUserPolicy = defaultUserPolicy
+			aclNetworkCacheMap[netID] = aclNetInfo
+		}
+	}
+}
+
 // CreateDefaultAclNetworkPolicies - create default acl network policies
 func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
 	if netID.String() == "" {
@@ -161,13 +198,20 @@ func listAclFromCache() (acls []models.Acl) {
 
 func storeAclInCache(a models.Acl) {
 	aclCacheMutex.Lock()
-	defer aclCacheMutex.Unlock()
+	defer func() {
+		aclCacheMutex.Unlock()
+		go loadNetworkAclsIntoCache()
+	}()
 	aclCacheMap[a.ID] = a
+
 }
 
// removeAclFromCache deletes an ACL from the flat cache, then triggers an
// asynchronous rebuild of the per-network cache.
func removeAclFromCache(a models.Acl) {
	aclCacheMutex.Lock()
	defer func() {
		// Unlock first so the rebuild goroutine can acquire the ACL caches
		// without contention from this caller.
		aclCacheMutex.Unlock()
		go loadNetworkAclsIntoCache()
	}()
	delete(aclCacheMap, a.ID)
}
 
@@ -514,6 +558,37 @@ func listPoliciesOfUser(user models.User, netID models.NetworkID) []models.Acl {
 	return userAcls
 }
 
+func GetDefaultPolicyFromNetCache(netID models.NetworkID, ruleType models.AclPolicyType) models.Acl {
+	aclNetCacheMutex.RLock()
+	defer aclNetCacheMutex.RUnlock()
+	if aclNetInfo, ok := aclNetworkCacheMap[netID]; ok {
+		if ruleType == models.DevicePolicy {
+			return aclNetInfo.DefaultDevicePolicy
+		} else {
+			return aclNetInfo.DefaultUserPolicy
+		}
+	}
+	return models.Acl{}
+}
+
+func listPolicesFromNetCache(netID models.NetworkID, ruleType models.AclPolicyType) []models.Acl {
+	aclNetCacheMutex.RLock()
+	if aclNetInfo, ok := aclNetworkCacheMap[netID]; ok {
+		if ruleType == models.DevicePolicy {
+			aclNetCacheMutex.RUnlock()
+			return aclNetInfo.DevicePolices
+		} else {
+			aclNetCacheMutex.RUnlock()
+			return aclNetInfo.UserPolicies
+		}
+	}
+	aclNetCacheMutex.RUnlock()
+	if ruleType == models.DevicePolicy {
+		return listDevicePolicies(netID)
+	}
+	return listUserPolicies(netID)
+}
+
 // listDevicePolicies - lists all device policies in a network
 func listDevicePolicies(netID models.NetworkID) []models.Acl {
 	allAcls := ListAcls()
@@ -627,9 +702,10 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 				return true
 			}
 		}
+
 	}
 	// list device policies
-	policies := listDevicePolicies(models.NetworkID(peer.Network))
+	policies := listPolicesFromNetCache(models.NetworkID(node.Network), models.DevicePolicy)
 	srcMap := make(map[string]struct{})
 	dstMap := make(map[string]struct{})
 	defer func() {

+ 2 - 4
logic/nodes.go

@@ -40,9 +40,7 @@ func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
 }
 func getNodesFromCache() (nodes []models.Node) {
 	nodeCacheMutex.RLock()
-	for _, node := range nodesCacheMap {
-		nodes = append(nodes, node)
-	}
+	nodes = slices.Collect(maps.Values(nodesCacheMap))
 	nodeCacheMutex.RUnlock()
 	return
 }
@@ -141,7 +139,7 @@ func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node
 		defer nodeNetworkCacheMutex.Unlock()
 		return slices.Collect(maps.Values(networkNodes))
 	}
-	var nodes = []models.Node{}
+	var nodes = make([]models.Node, 0, len(allNodes))
 	for i := range allNodes {
 		node := allNodes[i]
 		if node.Network == network {

+ 81 - 3
logic/peers.go

@@ -59,6 +59,80 @@ var (
 	}
 )
 
+// GetHostPeerInfo - fetches required peer info per network
+func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
+	peerInfo := models.HostPeerInfo{
+		NetworkPeerIDs: make(map[models.NetworkID]models.PeerMap),
+	}
+	allNodes, err := GetAllNodes()
+	if err != nil {
+		return peerInfo, err
+	}
+	for _, nodeID := range host.Nodes {
+		nodeID := nodeID
+		node, err := GetNodeByID(nodeID)
+		if err != nil {
+			continue
+		}
+
+		if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE {
+			continue
+		}
+		networkPeersInfo := make(models.PeerMap)
+		defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+
+		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
+		for _, peer := range currentPeers {
+			peer := peer
+			if peer.ID.String() == node.ID.String() {
+				logger.Log(2, "peer update, skipping self")
+				// skip yourself
+				continue
+			}
+
+			peerHost, err := GetHost(peer.HostID.String())
+			if err != nil {
+				logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
+				continue
+			}
+
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
+			if peer.Action != models.NODE_DELETE &&
+				!peer.PendingDelete &&
+				peer.Connected &&
+				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
+				(defaultDevicePolicy.Enabled || allowedToComm) {
+
+				networkPeersInfo[peerHost.PublicKey.String()] = models.IDandAddr{
+					ID:         peer.ID.String(),
+					HostID:     peerHost.ID.String(),
+					Address:    peer.PrimaryAddress(),
+					Name:       peerHost.Name,
+					Network:    peer.Network,
+					ListenPort: peerHost.ListenPort,
+				}
+
+			}
+		}
+		var extPeerIDAndAddrs []models.IDandAddr
+		if node.IsIngressGateway {
+			_, extPeerIDAndAddrs, _, err = GetExtPeers(&node, &node)
+			if err == nil {
+				for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
+					networkPeersInfo[extPeerIdAndAddr.ID] = extPeerIdAndAddr
+				}
+			}
+		}
+		peerInfo.NetworkPeerIDs[models.NetworkID(node.Network)] = networkPeersInfo
+	}
+	return peerInfo, nil
+}
+
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
 func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.Node,
 	deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
@@ -294,15 +368,19 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				peerConfig.Endpoint.IP = peer.LocalAddress.IP
 				peerConfig.Endpoint.Port = peerHost.ListenPort
 			}
-			allowedips := GetAllowedIPs(&node, &peer, nil)
-			allowedToComm := IsPeerAllowed(node, peer, false)
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
 			if peer.Action != models.NODE_DELETE &&
 				!peer.PendingDelete &&
 				peer.Connected &&
 				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
 				(defaultDevicePolicy.Enabled || allowedToComm) &&
 				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
-				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
+				peerConfig.AllowedIPs = GetAllowedIPs(&node, &peer, nil) // only append allowed IPs if valid connection
 			}
 
 			var nodePeer wgtypes.PeerConfig

+ 0 - 0
logic/pro/failover


+ 4 - 0
models/mqtt.go

@@ -6,6 +6,10 @@ import (
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
// HostPeerInfo - per-network peer info for a host.
type HostPeerInfo struct {
	// NetworkPeerIDs maps each network the host participates in to that
	// network's peer map.
	NetworkPeerIDs map[NetworkID]PeerMap `json:"network_peers"`
}
+
 // HostPeerUpdate - struct for host peer updates
 type HostPeerUpdate struct {
 	Host            Host                  `json:"host"`

+ 137 - 3
pro/controllers/failover.go

@@ -19,7 +19,7 @@ import (
 
 // FailOverHandlers - handlers for FailOver
 func FailOverHandlers(r *mux.Router) {
-	r.HandleFunc("/api/v1/node/{nodeid}/failover", http.HandlerFunc(getfailOver)).
+	r.HandleFunc("/api/v1/node/{nodeid}/failover", controller.Authorize(true, false, "host", http.HandlerFunc(getfailOver))).
 		Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover", logic.SecurityCheck(true, http.HandlerFunc(createfailOver))).
 		Methods(http.MethodPost)
@@ -29,6 +29,8 @@ func FailOverHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover_me", controller.Authorize(true, false, "host", http.HandlerFunc(failOverME))).
 		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/failover_check", controller.Authorize(true, false, "host", http.HandlerFunc(checkfailOverCtx))).
+		Methods(http.MethodGet)
 }
 
 // @Summary     Get failover node
@@ -44,7 +46,6 @@ func getfailOver(w http.ResponseWriter, r *http.Request) {
 	// confirm host exists
 	node, err := logic.GetNodeByID(nodeid)
 	if err != nil {
-		slog.Error("failed to get node:", "node", nodeid, "error", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
@@ -140,6 +141,7 @@ func deletefailOver(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	proLogic.RemoveFailOverFromCache(node.Network)
 	go func() {
 		proLogic.ResetFailOver(&node)
 		mq.PublishPeerUpdate(false)
@@ -268,7 +270,7 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 
 	err = proLogic.SetFailOverCtx(failOverNode, node, peerNode)
 	if err != nil {
-		slog.Error("failed to create failover", "id", node.ID.String(),
+		slog.Debug("failed to create failover", "id", node.ID.String(),
 			"network", node.Network, "error", err)
 		logic.ReturnErrorResponse(
 			w,
@@ -293,3 +295,135 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	logic.ReturnSuccessResponse(w, r, "relayed successfully")
 }
+
+// @Summary     Failover me
+// @Router      /api/v1/node/{nodeid}/failover_check [get]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.FailOverMeReq true "Failover request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	failOverNode, exists := proLogic.FailOverExists(node.Network)
+	if !exists {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, failover node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+	var failOverReq models.FailOverMeReq
+	err = json.NewDecoder(r.Body).Decode(&failOverReq)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	peerNode, err := logic.GetNodeByID(failOverReq.NodeID)
+	if err != nil {
+		slog.Error("peer not found: ", "nodeid", failOverReq.NodeID, "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer not found"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if node.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelayed && node.RelayedBy == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelay && peerNode.RelayedBy == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node acting as internet gw for the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+
+	err = proLogic.CheckFailOverCtx(failOverNode, node, peerNode)
+	if err != nil {
+		slog.Error("failover ctx cannot be set ", "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(fmt.Errorf("failover ctx cannot be set: %v", err), "internal"),
+		)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "failover can be set")
+}

+ 1 - 0
pro/initialize.go

@@ -90,6 +90,7 @@ func InitPro() {
 			slog.Error("no OAuth provider found or not configured, continuing without OAuth")
 		}
 		proLogic.LoadNodeMetricsToCache()
+		proLogic.InitFailOverCache()
 	})
 	logic.ResetFailOver = proLogic.ResetFailOver
 	logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer

+ 67 - 12
pro/logic/failover.go

@@ -13,7 +13,49 @@ import (
 )
 
 var failOverCtxMutex = &sync.RWMutex{}
+var failOverCacheMutex = &sync.RWMutex{}
+var failOverCache = make(map[models.NetworkID]string)
 
+func InitFailOverCache() {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	networks, err := logic.GetNetworks()
+	if err != nil {
+		return
+	}
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		return
+	}
+
+	for _, network := range networks {
+		networkNodes := logic.GetNetworkNodesMemory(allNodes, network.NetID)
+		for _, node := range networkNodes {
+			if node.IsFailOver {
+				failOverCache[models.NetworkID(network.NetID)] = node.ID.String()
+				break
+			}
+		}
+	}
+}
+
// CheckFailOverCtx reports whether a failover context between victimNode and
// peerNode, brokered by failOverNode, is already fully established; it
// returns an error in that case and nil otherwise.
//
// NOTE(review): only the passed-in node values are read; the RLock on
// failOverCtxMutex presumably serializes this check against concurrent
// SetFailOverCtx writers — confirm that is the intent.
func CheckFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
	failOverCtxMutex.RLock()
	defer failOverCtxMutex.RUnlock()
	// A nil FailOverPeers map means no failover relationship exists yet.
	if peerNode.FailOverPeers == nil {
		return nil
	}
	if victimNode.FailOverPeers == nil {
		return nil
	}
	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
	// The ctx counts as set only when both sides reference each other AND
	// both were failed over by this particular failover node.
	if peerHasFailovered && victimHasFailovered &&
		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID {
		return errors.New("failover ctx is already set")
	}
	return nil
}
 func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	failOverCtxMutex.Lock()
 	defer failOverCtxMutex.Unlock()
@@ -23,13 +65,16 @@ func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	if victimNode.FailOverPeers == nil {
 		victimNode.FailOverPeers = make(map[string]struct{})
 	}
+	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
+	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
+	if peerHasFailovered && victimHasFailovered &&
+		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID {
+		return errors.New("failover ctx is already set")
+	}
 	peerNode.FailOverPeers[victimNode.ID.String()] = struct{}{}
 	victimNode.FailOverPeers[peerNode.ID.String()] = struct{}{}
 	victimNode.FailedOverBy = failOverNode.ID
 	peerNode.FailedOverBy = failOverNode.ID
-	if err := logic.UpsertNode(&failOverNode); err != nil {
-		return err
-	}
 	if err := logic.UpsertNode(&victimNode); err != nil {
 		return err
 	}
@@ -50,17 +95,26 @@ func GetFailOverNode(network string, allNodes []models.Node) (models.Node, error
 	return models.Node{}, errors.New("auto relay not found")
 }
 
// RemoveFailOverFromCache drops the cached failover node ID for a network,
// e.g. after the network's failover node is deleted.
func RemoveFailOverFromCache(network string) {
	failOverCacheMutex.Lock()
	defer failOverCacheMutex.Unlock()
	delete(failOverCache, models.NetworkID(network))
}
+
// SetFailOverInCache records the given node as its network's failover node
// in the in-memory cache (overwriting any previous entry for that network).
func SetFailOverInCache(node models.Node) {
	failOverCacheMutex.Lock()
	defer failOverCacheMutex.Unlock()
	failOverCache[models.NetworkID(node.Network)] = node.ID.String()
}
+
// FailOverExists - checks if failOver exists already in the network.
// The answer is served from the in-memory failover cache; a cached node ID
// is re-validated against storage before being returned.
func FailOverExists(network string) (failOverNode models.Node, exists bool) {
	failOverCacheMutex.RLock()
	defer failOverCacheMutex.RUnlock()
	if nodeID, ok := failOverCache[models.NetworkID(network)]; ok {
		// Shadows the named return deliberately; the explicit return below
		// carries the looked-up node out.
		failOverNode, err := logic.GetNodeByID(nodeID)
		if err == nil {
			return failOverNode, true
		}
	}
	// NOTE(review): a cache miss (or stale entry whose node lookup fails)
	// reports "no failover" without consulting storage — assumes the cache
	// is kept complete by Init/Set/Remove; confirm for multi-server setups.
	return
}
@@ -185,5 +239,6 @@ func CreateFailOver(node models.Node) error {
 		slog.Error("failed to upsert node", "node", node.ID.String(), "error", err)
 		return err
 	}
+	SetFailOverInCache(node)
 	return nil
 }