Преглед на файлове

resolve merge conflicts

abhishek9686 преди 7 месеца
родител
ревизия
41fa0b1bce
променени са 14 файла, в които са добавени 347 реда и са изтрити 53 реда
  1. 1 1
      .github/ISSUE_TEMPLATE/bug-report.yml
  2. 2 0
      Dockerfile
  3. 1 9
      compose/docker-compose.yml
  4. 35 1
      controllers/hosts.go
  5. 0 4
      docker/Caddyfile
  6. 16 10
      logic/acls.go
  7. 0 5
      logic/extpeers.go
  8. 2 4
      logic/nodes.go
  9. 81 3
      logic/peers.go
  10. 0 0
      logic/pro/failover
  11. 4 0
      models/mqtt.go
  12. 137 4
      pro/controllers/failover.go
  13. 1 0
      pro/initialize.go
  14. 67 12
      pro/logic/failover.go

+ 1 - 1
.github/ISSUE_TEMPLATE/bug-report.yml

@@ -2,7 +2,7 @@ name: Bug Report
 description: File a bug report
 title: "[Bug]: "
 labels: ["bug", "triage"]
-assignees: ["ok-john", "0xdcarns", "afeiszli",  "mattkasun"]
+assignees: ["abhishek9686","VishalDalwadi","Aceix","dentadlp"]
 body:
   - type: markdown
     attributes:

+ 2 - 0
Dockerfile

@@ -11,6 +11,8 @@ FROM alpine:3.21.2
 # add a c lib
 # set the working directory
 WORKDIR /root/
+RUN apk update && apk upgrade
+RUN apk add --no-cache sqlite
 RUN mkdir -p /etc/netclient/config
 COPY --from=builder /app/netmaker .
 COPY --from=builder /app/config config

+ 1 - 9
compose/docker-compose.yml

@@ -12,7 +12,7 @@ services:
       - sqldata:/root/data
     environment:
       # config-dependant vars
-      - STUN_SERVERS=stun.${NM_DOMAIN}:3478,stun1.l.google.com:19302,stun2.l.google.com:19302,stun3.l.google.com:19302,stun4.l.google.com:19302
+      - STUN_SERVERS=stun1.l.google.com:19302,stun2.l.google.com:19302,stun3.l.google.com:19302,stun4.l.google.com:19302
       # The domain/host IP indicating the mq broker address
       - BROKER_ENDPOINT=wss://broker.${NM_DOMAIN} # For EMQX broker use `BROKER_ENDPOINT=wss://broker.${NM_DOMAIN}/mqtt`
       # For EMQX broker (uncomment the two lines below)
@@ -39,14 +39,6 @@ services:
     links:
       - "netmaker:api"
     restart: always
-  stun:
-    container_name: stun
-    image: coturn/coturn
-    restart: always
-    ports:
-      - "3478:3478/udp"   # STUN UDP
-    environment:
-     - LISTENING_PORT=3478
 
   caddy:
     image: caddy:2.8.4

+ 35 - 1
controllers/hosts.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"time"
 
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
@@ -48,6 +49,8 @@ func hostHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/fallback/host/{hostid}", Authorize(true, false, "host", http.HandlerFunc(hostUpdateFallback))).
 		Methods(http.MethodPut)
+	r.HandleFunc("/api/v1/host/{hostid}/peer_info", Authorize(true, false, "host", http.HandlerFunc(getHostPeerInfo))).
+		Methods(http.MethodGet)
 	r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))).
 		Methods(http.MethodDelete)
 	r.HandleFunc("/api/v1/auth-register/host", socketHandler)
@@ -232,7 +235,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
 			slog.Error("failed to get node:", "id", node.ID, "error", err)
 			continue
 		}
-		if node.FailedOverBy != uuid.Nil {
+		if node.FailedOverBy != uuid.Nil && r.URL.Query().Get("reset_failovered") == "true" {
 			logic.ResetFailedOverPeer(&node)
 			sendPeerUpdate = true
 		}
@@ -943,6 +946,7 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
 					slog.Info("host sync requested", "user", user, "host", host.ID.String())
 				}
 			}(host)
+			time.Sleep(time.Millisecond * 100)
 		}
 	}()
 
@@ -1017,3 +1021,33 @@ func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
 	}
 	logic.ReturnSuccessResponse(w, r, "deleted hosts data on emqx")
 }
+
+// @Summary     Fetches host peerinfo
+// @Router      /api/host/{hostid}/peer_info [get]
+// @Tags        Hosts
+// @Security    oauth
+// @Param       hostid path string true "Host ID"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     500 {object} models.ErrorResponse
+func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
+	hostId := mux.Vars(r)["hostid"]
+	var errorResponse = models.ErrorResponse{}
+
+	host, err := logic.GetHost(hostId)
+	if err != nil {
+		slog.Error("failed to retrieve host", "error", err)
+		errorResponse.Code = http.StatusBadRequest
+		errorResponse.Message = err.Error()
+		logic.ReturnErrorResponse(w, r, errorResponse)
+		return
+	}
+	peerInfo, err := logic.GetHostPeerInfo(host)
+	if err != nil {
+		slog.Error("failed to retrieve host peerinfo", "error", err)
+		errorResponse.Code = http.StatusBadRequest
+		errorResponse.Message = err.Error()
+		logic.ReturnErrorResponse(w, r, errorResponse)
+		return
+	}
+	logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info")
+}

+ 0 - 4
docker/Caddyfile

@@ -32,7 +32,3 @@ broker.{$NM_DOMAIN} {
 		}
 	reverse_proxy @ws mq:8883   # For EMQX websockets use `reverse_proxy @ws mq:8083`
 }
-
-https://stun.{$NM_DOMAIN} {
-	reverse_proxy stun:3478
-}

+ 16 - 10
logic/acls.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"maps"
 	"sort"
 	"sync"
 	"time"
@@ -163,6 +164,7 @@ func storeAclInCache(a models.Acl) {
 	aclCacheMutex.Lock()
 	defer aclCacheMutex.Unlock()
 	aclCacheMap[a.ID] = a
+
 }
 
 func removeAclFromCache(a models.Acl) {
@@ -605,6 +607,8 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []mode
 // IsPeerAllowed - checks if peer needs to be added to the interface
 func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 	var nodeId, peerId string
+	peerTags := maps.Clone(peer.Tags)
+	nodeTags := maps.Clone(node.Tags)
 	if node.IsStatic {
 		nodeId = node.StaticNode.ClientID
 		node = node.StaticNode.ConvertToStaticNode()
@@ -617,8 +621,8 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 	} else {
 		peerId = peer.ID.String()
 	}
-	node.Tags[models.TagID(nodeId)] = struct{}{}
-	peer.Tags[models.TagID(peerId)] = struct{}{}
+	nodeTags[models.TagID(nodeId)] = struct{}{}
+	peerTags[models.TagID(peerId)] = struct{}{}
 	if checkDefaultPolicy {
 		// check default policy if all allowed return true
 		defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@@ -627,6 +631,7 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 				return true
 			}
 		}
+
 	}
 	// list device policies
 	policies := listDevicePolicies(models.NetworkID(peer.Network))
@@ -642,7 +647,7 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 		}
 		srcMap = convAclTagToValueMap(policy.Src)
 		dstMap = convAclTagToValueMap(policy.Dst)
-		if checkTagGroupPolicy(srcMap, dstMap, node, peer) {
+		if checkTagGroupPolicy(srcMap, dstMap, node, peer, nodeTags, peerTags) {
 			return true
 		}
 
@@ -650,7 +655,8 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 	return false
 }
 
-func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.Node) bool {
+func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.Node,
+	nodeTags, peerTags map[models.TagID]struct{}) bool {
 	// check for node ID
 	if _, ok := srcMap[node.ID.String()]; ok {
 		if _, ok = dstMap[peer.ID.String()]; ok {
@@ -664,12 +670,12 @@ func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.N
 		}
 	}
 
-	for tagID := range node.Tags {
+	for tagID := range nodeTags {
 		if _, ok := dstMap[tagID.String()]; ok {
 			if _, ok := srcMap["*"]; ok {
 				return true
 			}
-			for tagID := range peer.Tags {
+			for tagID := range peerTags {
 				if _, ok := srcMap[tagID.String()]; ok {
 					return true
 				}
@@ -679,19 +685,19 @@ func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.N
 			if _, ok := dstMap["*"]; ok {
 				return true
 			}
-			for tagID := range peer.Tags {
+			for tagID := range peerTags {
 				if _, ok := dstMap[tagID.String()]; ok {
 					return true
 				}
 			}
 		}
 	}
-	for tagID := range peer.Tags {
+	for tagID := range peerTags {
 		if _, ok := dstMap[tagID.String()]; ok {
 			if _, ok := srcMap["*"]; ok {
 				return true
 			}
-			for tagID := range node.Tags {
+			for tagID := range nodeTags {
 
 				if _, ok := srcMap[tagID.String()]; ok {
 					return true
@@ -702,7 +708,7 @@ func checkTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.N
 			if _, ok := dstMap["*"]; ok {
 				return true
 			}
-			for tagID := range node.Tags {
+			for tagID := range nodeTags {
 				if _, ok := dstMap[tagID.String()]; ok {
 					return true
 				}

+ 0 - 5
logic/extpeers.go

@@ -456,17 +456,12 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
 
 func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 	// fetch user access to static clients via policies
-	defer func() {
-		logger.Log(0, fmt.Sprintf("node.ID: %s, Rules: %+v\n", node.ID, rules))
-	}()
 
 	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 	nodes, _ := GetNetworkNodes(node.Network)
 	nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
-	//fmt.Printf("=====> NODES: %+v \n\n", nodes)
 	userNodes := GetStaticUserNodesByNetwork(models.NetworkID(node.Network))
-	//fmt.Printf("=====> USER NODES %+v \n\n", userNodes)
 	for _, userNodeI := range userNodes {
 		for _, peer := range nodes {
 			if peer.IsUserNode {

+ 2 - 4
logic/nodes.go

@@ -40,9 +40,7 @@ func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
 }
 func getNodesFromCache() (nodes []models.Node) {
 	nodeCacheMutex.RLock()
-	for _, node := range nodesCacheMap {
-		nodes = append(nodes, node)
-	}
+	nodes = slices.Collect(maps.Values(nodesCacheMap))
 	nodeCacheMutex.RUnlock()
 	return
 }
@@ -141,7 +139,7 @@ func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node
 		defer nodeNetworkCacheMutex.Unlock()
 		return slices.Collect(maps.Values(networkNodes))
 	}
-	var nodes = []models.Node{}
+	var nodes = make([]models.Node, 0, len(allNodes))
 	for i := range allNodes {
 		node := allNodes[i]
 		if node.Network == network {

+ 81 - 3
logic/peers.go

@@ -59,6 +59,80 @@ var (
 	}
 )
 
+// GetHostPeerInfo - fetches required peer info per network
+func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
+	peerInfo := models.HostPeerInfo{
+		NetworkPeerIDs: make(map[models.NetworkID]models.PeerMap),
+	}
+	allNodes, err := GetAllNodes()
+	if err != nil {
+		return peerInfo, err
+	}
+	for _, nodeID := range host.Nodes {
+		nodeID := nodeID
+		node, err := GetNodeByID(nodeID)
+		if err != nil {
+			continue
+		}
+
+		if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE {
+			continue
+		}
+		networkPeersInfo := make(models.PeerMap)
+		defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+
+		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
+		for _, peer := range currentPeers {
+			peer := peer
+			if peer.ID.String() == node.ID.String() {
+				logger.Log(2, "peer update, skipping self")
+				// skip yourself
+				continue
+			}
+
+			peerHost, err := GetHost(peer.HostID.String())
+			if err != nil {
+				logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
+				continue
+			}
+
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
+			if peer.Action != models.NODE_DELETE &&
+				!peer.PendingDelete &&
+				peer.Connected &&
+				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
+				(defaultDevicePolicy.Enabled || allowedToComm) {
+
+				networkPeersInfo[peerHost.PublicKey.String()] = models.IDandAddr{
+					ID:         peer.ID.String(),
+					HostID:     peerHost.ID.String(),
+					Address:    peer.PrimaryAddress(),
+					Name:       peerHost.Name,
+					Network:    peer.Network,
+					ListenPort: peerHost.ListenPort,
+				}
+
+			}
+		}
+		var extPeerIDAndAddrs []models.IDandAddr
+		if node.IsIngressGateway {
+			_, extPeerIDAndAddrs, _, err = GetExtPeers(&node, &node)
+			if err == nil {
+				for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
+					networkPeersInfo[extPeerIdAndAddr.ID] = extPeerIdAndAddr
+				}
+			}
+		}
+		peerInfo.NetworkPeerIDs[models.NetworkID(node.Network)] = networkPeersInfo
+	}
+	return peerInfo, nil
+}
+
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
 func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.Node,
 	deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
@@ -294,15 +368,19 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				peerConfig.Endpoint.IP = peer.LocalAddress.IP
 				peerConfig.Endpoint.Port = peerHost.ListenPort
 			}
-			allowedips := GetAllowedIPs(&node, &peer, nil)
-			allowedToComm := IsPeerAllowed(node, peer, false)
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
 			if peer.Action != models.NODE_DELETE &&
 				!peer.PendingDelete &&
 				peer.Connected &&
 				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
 				(defaultDevicePolicy.Enabled || allowedToComm) &&
 				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
-				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
+				peerConfig.AllowedIPs = GetAllowedIPs(&node, &peer, nil) // only append allowed IPs if valid connection
 			}
 
 			var nodePeer wgtypes.PeerConfig

+ 0 - 0
logic/pro/failover


+ 4 - 0
models/mqtt.go

@@ -6,6 +6,10 @@ import (
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
+type HostPeerInfo struct {
+	NetworkPeerIDs map[NetworkID]PeerMap `json:"network_peers"`
+}
+
 // HostPeerUpdate - struct for host peer updates
 type HostPeerUpdate struct {
 	Host            Host                  `json:"host"`

+ 137 - 4
pro/controllers/failover.go

@@ -19,7 +19,7 @@ import (
 
 // FailOverHandlers - handlers for FailOver
 func FailOverHandlers(r *mux.Router) {
-	r.HandleFunc("/api/v1/node/{nodeid}/failover", http.HandlerFunc(getfailOver)).
+	r.HandleFunc("/api/v1/node/{nodeid}/failover", controller.Authorize(true, false, "host", http.HandlerFunc(getfailOver))).
 		Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover", logic.SecurityCheck(true, http.HandlerFunc(createfailOver))).
 		Methods(http.MethodPost)
@@ -29,6 +29,8 @@ func FailOverHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover_me", controller.Authorize(true, false, "host", http.HandlerFunc(failOverME))).
 		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/failover_check", controller.Authorize(true, false, "host", http.HandlerFunc(checkfailOverCtx))).
+		Methods(http.MethodGet)
 }
 
 // @Summary     Get failover node
@@ -44,7 +46,6 @@ func getfailOver(w http.ResponseWriter, r *http.Request) {
 	// confirm host exists
 	node, err := logic.GetNodeByID(nodeid)
 	if err != nil {
-		slog.Error("failed to get node:", "node", nodeid, "error", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
@@ -140,6 +141,7 @@ func deletefailOver(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	proLogic.RemoveFailOverFromCache(node.Network)
 	go func() {
 		proLogic.ResetFailOver(&node)
 		mq.PublishPeerUpdate(false)
@@ -265,10 +267,9 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 		)
 		return
 	}
-
 	err = proLogic.SetFailOverCtx(failOverNode, node, peerNode)
 	if err != nil {
-		slog.Error("failed to create failover", "id", node.ID.String(),
+		slog.Debug("failed to create failover", "id", node.ID.String(),
 			"network", node.Network, "error", err)
 		logic.ReturnErrorResponse(
 			w,
@@ -293,3 +294,135 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	logic.ReturnSuccessResponse(w, r, "relayed successfully")
 }
+
+// @Summary     checkfailOverCtx
+// @Router      /api/v1/node/{nodeid}/failover_check [get]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.FailOverMeReq true "Failover request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	failOverNode, exists := proLogic.FailOverExists(node.Network)
+	if !exists {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, failover node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+	var failOverReq models.FailOverMeReq
+	err = json.NewDecoder(r.Body).Decode(&failOverReq)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	peerNode, err := logic.GetNodeByID(failOverReq.NodeID)
+	if err != nil {
+		slog.Error("peer not found: ", "nodeid", failOverReq.NodeID, "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer not found"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if node.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelayed && node.RelayedBy == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelay && peerNode.RelayedBy == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node acting as internet gw for the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+
+	err = proLogic.CheckFailOverCtx(failOverNode, node, peerNode)
+	if err != nil {
+		slog.Error("failover ctx cannot be set ", "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(fmt.Errorf("failover ctx cannot be set: %v", err), "internal"),
+		)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "failover can be set")
+}

+ 1 - 0
pro/initialize.go

@@ -90,6 +90,7 @@ func InitPro() {
 			slog.Error("no OAuth provider found or not configured, continuing without OAuth")
 		}
 		proLogic.LoadNodeMetricsToCache()
+		proLogic.InitFailOverCache()
 	})
 	logic.ResetFailOver = proLogic.ResetFailOver
 	logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer

+ 67 - 12
pro/logic/failover.go

@@ -13,7 +13,49 @@ import (
 )
 
 var failOverCtxMutex = &sync.RWMutex{}
+var failOverCacheMutex = &sync.RWMutex{}
+var failOverCache = make(map[models.NetworkID]string)
 
+func InitFailOverCache() {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	networks, err := logic.GetNetworks()
+	if err != nil {
+		return
+	}
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		return
+	}
+
+	for _, network := range networks {
+		networkNodes := logic.GetNetworkNodesMemory(allNodes, network.NetID)
+		for _, node := range networkNodes {
+			if node.IsFailOver {
+				failOverCache[models.NetworkID(network.NetID)] = node.ID.String()
+				break
+			}
+		}
+	}
+}
+
+func CheckFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
+	failOverCtxMutex.RLock()
+	defer failOverCtxMutex.RUnlock()
+	if peerNode.FailOverPeers == nil {
+		return nil
+	}
+	if victimNode.FailOverPeers == nil {
+		return nil
+	}
+	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
+	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
+	if peerHasFailovered && victimHasFailovered &&
+		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID {
+		return errors.New("failover ctx is already set")
+	}
+	return nil
+}
 func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	failOverCtxMutex.Lock()
 	defer failOverCtxMutex.Unlock()
@@ -23,13 +65,16 @@ func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	if victimNode.FailOverPeers == nil {
 		victimNode.FailOverPeers = make(map[string]struct{})
 	}
+	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
+	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
+	if peerHasFailovered && victimHasFailovered &&
+		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID {
+		return errors.New("failover ctx is already set")
+	}
 	peerNode.FailOverPeers[victimNode.ID.String()] = struct{}{}
 	victimNode.FailOverPeers[peerNode.ID.String()] = struct{}{}
 	victimNode.FailedOverBy = failOverNode.ID
 	peerNode.FailedOverBy = failOverNode.ID
-	if err := logic.UpsertNode(&failOverNode); err != nil {
-		return err
-	}
 	if err := logic.UpsertNode(&victimNode); err != nil {
 		return err
 	}
@@ -50,17 +95,26 @@ func GetFailOverNode(network string, allNodes []models.Node) (models.Node, error
 	return models.Node{}, errors.New("auto relay not found")
 }
 
+func RemoveFailOverFromCache(network string) {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	delete(failOverCache, models.NetworkID(network))
+}
+
+func SetFailOverInCache(node models.Node) {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	failOverCache[models.NetworkID(node.Network)] = node.ID.String()
+}
+
 // FailOverExists - checks if failOver exists already in the network
 func FailOverExists(network string) (failOverNode models.Node, exists bool) {
-	nodes, err := logic.GetNetworkNodes(network)
-	if err != nil {
-		return
-	}
-	for _, node := range nodes {
-		if node.IsFailOver {
-			exists = true
-			failOverNode = node
-			return
+	failOverCacheMutex.RLock()
+	defer failOverCacheMutex.RUnlock()
+	if nodeID, ok := failOverCache[models.NetworkID(network)]; ok {
+		failOverNode, err := logic.GetNodeByID(nodeID)
+		if err == nil {
+			return failOverNode, true
 		}
 	}
 	return
@@ -185,5 +239,6 @@ func CreateFailOver(node models.Node) error {
 		slog.Error("failed to upsert node", "node", node.ID.String(), "error", err)
 		return err
 	}
+	SetFailOverInCache(node)
 	return nil
 }