Prechádzať zdrojové kódy

Merge branch 'develop' of https://github.com/gravitl/netmaker into NET-1956

abhishek9686 6 mesiacov pred
rodič
commit
c4b39e476c

+ 1 - 1
.github/ISSUE_TEMPLATE/bug-report.yml

@@ -2,7 +2,7 @@ name: Bug Report
 description: File a bug report
 title: "[Bug]: "
 labels: ["bug", "triage"]
-assignees: ["ok-john", "0xdcarns", "afeiszli",  "mattkasun"]
+assignees: ["abhishek9686","VishalDalwadi","Aceix","dentadlp"]
 body:
   - type: markdown
     attributes:

+ 2 - 0
Dockerfile

@@ -11,6 +11,8 @@ FROM alpine:3.21.2
 # add a c lib
 # set the working directory
 WORKDIR /root/
+RUN apk update && apk upgrade
+RUN apk add --no-cache sqlite
 RUN mkdir -p /etc/netclient/config
 COPY --from=builder /app/netmaker .
 COPY --from=builder /app/config config

+ 36 - 2
controllers/hosts.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"net/http"
+	"time"
 
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
@@ -48,6 +49,8 @@ func hostHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/fallback/host/{hostid}", Authorize(true, false, "host", http.HandlerFunc(hostUpdateFallback))).
 		Methods(http.MethodPut)
+	r.HandleFunc("/api/v1/host/{hostid}/peer_info", Authorize(true, false, "host", http.HandlerFunc(getHostPeerInfo))).
+		Methods(http.MethodGet)
 	r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))).
 		Methods(http.MethodDelete)
 	r.HandleFunc("/api/v1/auth-register/host", socketHandler)
@@ -232,7 +235,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
 			slog.Error("failed to get node:", "id", node.ID, "error", err)
 			continue
 		}
-		if node.FailedOverBy != uuid.Nil {
+		if node.FailedOverBy != uuid.Nil && r.URL.Query().Get("reset_failovered") == "true" {
 			logic.ResetFailedOverPeer(&node)
 			sendPeerUpdate = true
 		}
@@ -368,7 +371,7 @@ func hostUpdateFallback(w http.ResponseWriter, r *http.Request) {
 	var hostUpdate models.HostUpdate
 	err = json.NewDecoder(r.Body).Decode(&hostUpdate)
 	if err != nil {
-		slog.Error("failed to update a host:", "user", r.Header.Get("user"), "error", err.Error())
+		slog.Error("failed to update a host:", "user", r.Header.Get("user"), "error", err.Error(), "host", currentHost.Name)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
@@ -943,6 +946,7 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
 					slog.Info("host sync requested", "user", user, "host", host.ID.String())
 				}
 			}(host)
+			time.Sleep(time.Millisecond * 100)
 		}
 	}()
 
@@ -1017,3 +1021,33 @@ func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
 	}
 	logic.ReturnSuccessResponse(w, r, "deleted hosts data on emqx")
 }
+
+// @Summary     Fetches host peerinfo
+// @Router      /api/host/{hostid}/peer_info [get]
+// @Tags        Hosts
+// @Security    oauth
+// @Param       hostid path string true "Host ID"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     500 {object} models.ErrorResponse
+func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
+	hostId := mux.Vars(r)["hostid"]
+	var errorResponse = models.ErrorResponse{}
+
+	host, err := logic.GetHost(hostId)
+	if err != nil {
+		slog.Error("failed to retrieve host", "error", err)
+		errorResponse.Code = http.StatusBadRequest
+		errorResponse.Message = err.Error()
+		logic.ReturnErrorResponse(w, r, errorResponse)
+		return
+	}
+	peerInfo, err := logic.GetHostPeerInfo(host)
+	if err != nil {
+		slog.Error("failed to retrieve host peerinfo", "error", err)
+		errorResponse.Code = http.StatusBadRequest
+		errorResponse.Message = err.Error()
+		logic.ReturnErrorResponse(w, r, errorResponse)
+		return
+	}
+	logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info")
+}

+ 1 - 0
controllers/network.go

@@ -464,6 +464,7 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype))
 		return
 	}
+	go logic.UnlinkNetworkAndTagsFromEnrollmentKeys(network, true)
 	go logic.DeleteNetworkRoles(network)
 	go logic.DeleteDefaultNetworkPolicies(models.NetworkID(network))
 	//delete network from allocated ip map

+ 9 - 0
controllers/user.go

@@ -258,6 +258,15 @@ func getUserV1(w http.ResponseWriter, r *http.Request) {
 	resp := models.ReturnUserWithRolesAndGroups{
 		ReturnUser:   user,
 		PlatformRole: userRoleTemplate,
+		UserGroups:   map[models.UserGroupID]models.UserGroup{},
+	}
+	for gId := range user.UserGroups {
+		grp, err := logic.GetUserGroup(gId)
+		if err != nil {
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
+		}
+		resp.UserGroups[gId] = grp
 	}
 	logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched)
 	logic.ReturnSuccessResponseWithJson(w, r, resp, "fetched user with role info")

+ 25 - 12
logic/acls.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"maps"
 	"sort"
 	"sync"
 	"time"
@@ -16,6 +17,7 @@ import (
 var (
 	aclCacheMutex = &sync.RWMutex{}
 	aclCacheMap   = make(map[string]models.Acl)
+	aclTagsMutex  = &sync.RWMutex{}
 )
 
 func MigrateAclPolicies() {
@@ -163,6 +165,7 @@ func storeAclInCache(a models.Acl) {
 	aclCacheMutex.Lock()
 	defer aclCacheMutex.Unlock()
 	aclCacheMap[a.ID] = a
+
 }
 
 func removeAclFromCache(a models.Acl) {
@@ -574,6 +577,10 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 	if peer.IsStatic {
 		peer = peer.StaticNode.ConvertToStaticNode()
 	}
+	aclTagsMutex.RLock()
+	peerTags := maps.Clone(peer.Tags)
+	nodeTags := maps.Clone(node.Tags)
+	aclTagsMutex.RUnlock()
 	if checkDefaultPolicy {
 		// check default policy if all allowed return true
 		defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@@ -582,6 +589,7 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 				return true
 			}
 		}
+
 	}
 	// list device policies
 	policies := listDevicePolicies(models.NetworkID(peer.Network))
@@ -597,12 +605,12 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 		}
 		srcMap = convAclTagToValueMap(policy.Src)
 		dstMap = convAclTagToValueMap(policy.Dst)
-		for tagID := range node.Tags {
+		for tagID := range nodeTags {
 			if _, ok := dstMap[tagID.String()]; ok {
 				if _, ok := srcMap["*"]; ok {
 					return true
 				}
-				for tagID := range peer.Tags {
+				for tagID := range peerTags {
 					if _, ok := srcMap[tagID.String()]; ok {
 						return true
 					}
@@ -612,19 +620,20 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 				if _, ok := dstMap["*"]; ok {
 					return true
 				}
-				for tagID := range peer.Tags {
+				for tagID := range peerTags {
 					if _, ok := dstMap[tagID.String()]; ok {
 						return true
 					}
 				}
 			}
 		}
-		for tagID := range peer.Tags {
+
+		for tagID := range peerTags {
 			if _, ok := dstMap[tagID.String()]; ok {
 				if _, ok := srcMap["*"]; ok {
 					return true
 				}
-				for tagID := range node.Tags {
+				for tagID := range nodeTags {
 
 					if _, ok := srcMap[tagID.String()]; ok {
 						return true
@@ -635,7 +644,7 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool {
 				if _, ok := dstMap["*"]; ok {
 					return true
 				}
-				for tagID := range node.Tags {
+				for tagID := range nodeTags {
 					if _, ok := dstMap[tagID.String()]; ok {
 						return true
 					}
@@ -654,6 +663,10 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 	if peer.IsStatic {
 		peer = peer.StaticNode.ConvertToStaticNode()
 	}
+	aclTagsMutex.RLock()
+	peerTags := maps.Clone(peer.Tags)
+	nodeTags := maps.Clone(node.Tags)
+	aclTagsMutex.RUnlock()
 	if checkDefaultPolicy {
 		// check default policy if all allowed return true
 		defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
@@ -678,7 +691,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 		}
 		srcMap = convAclTagToValueMap(policy.Src)
 		dstMap = convAclTagToValueMap(policy.Dst)
-		for tagID := range node.Tags {
+		for tagID := range nodeTags {
 			allowed := false
 			if _, ok := dstMap[tagID.String()]; policy.AllowedDirection == models.TrafficDirectionBi && ok {
 				if _, ok := srcMap["*"]; ok {
@@ -686,7 +699,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 					allowedPolicies = append(allowedPolicies, policy)
 					break
 				}
-				for tagID := range peer.Tags {
+				for tagID := range peerTags {
 					if _, ok := srcMap[tagID.String()]; ok {
 						allowed = true
 						break
@@ -703,7 +716,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 					allowedPolicies = append(allowedPolicies, policy)
 					break
 				}
-				for tagID := range peer.Tags {
+				for tagID := range peerTags {
 					if _, ok := dstMap[tagID.String()]; ok {
 						allowed = true
 						break
@@ -715,7 +728,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 				break
 			}
 		}
-		for tagID := range peer.Tags {
+		for tagID := range peerTags {
 			allowed := false
 			if _, ok := dstMap[tagID.String()]; ok {
 				if _, ok := srcMap["*"]; ok {
@@ -723,7 +736,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 					allowedPolicies = append(allowedPolicies, policy)
 					break
 				}
-				for tagID := range node.Tags {
+				for tagID := range nodeTags {
 
 					if _, ok := srcMap[tagID.String()]; ok {
 						allowed = true
@@ -742,7 +755,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
 					allowedPolicies = append(allowedPolicies, policy)
 					break
 				}
-				for tagID := range node.Tags {
+				for tagID := range nodeTags {
 					if _, ok := dstMap[tagID.String()]; ok {
 						allowed = true
 						break

+ 57 - 1
logic/enrollmentkey.go

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"strings"
 	"sync"
 	"time"
 
@@ -120,7 +121,6 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID)
 }
 
 // GetAllEnrollmentKeys - fetches all enrollment keys from DB
-// TODO drop double pointer
 func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) {
 	currentKeys, err := getEnrollmentKeysMap()
 	if err != nil {
@@ -335,3 +335,59 @@ func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) {
 
 	}
 }
+
+func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error {
+	keys, err := GetAllEnrollmentKeys()
+	if err != nil {
+		return fmt.Errorf("failed to retrieve keys: %w", err)
+	}
+
+	var errs []error
+	for _, key := range keys {
+		newNetworks := []string{}
+		newTags := []models.TagID{}
+		update := false
+
+		// Check and update networks
+		for _, net := range key.Networks {
+			if net == network {
+				update = true
+				continue
+			}
+			newNetworks = append(newNetworks, net)
+		}
+
+		// Check and update tags
+		for _, tag := range key.Groups {
+			tagParts := strings.Split(tag.String(), ".")
+			if len(tagParts) == 0 {
+				continue
+			}
+			tagNetwork := tagParts[0]
+			if tagNetwork == network {
+				update = true
+				continue
+			}
+			newTags = append(newTags, tag)
+		}
+
+		if update && len(newNetworks) == 0 && delete {
+			if err := DeleteEnrollmentKey(key.Value, true); err != nil {
+				errs = append(errs, fmt.Errorf("failed to delete key %s: %w", key.Value, err))
+			}
+			continue
+		}
+		if update {
+			key.Networks = newNetworks
+			key.Groups = newTags
+			if err := upsertEnrollmentKey(&key); err != nil {
+				errs = append(errs, fmt.Errorf("failed to update key %s: %w", key.Value, err))
+			}
+		}
+	}
+
+	if len(errs) > 0 {
+		return fmt.Errorf("errors unlinking network/tags from keys: %v", errs)
+	}
+	return nil
+}

+ 0 - 5
logic/extpeers.go

@@ -456,17 +456,12 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
 
 func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
 	// fetch user access to static clients via policies
-	defer func() {
-		logger.Log(0, fmt.Sprintf("node.ID: %s, Rules: %+v\n", node.ID, rules))
-	}()
 
 	defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
 	defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
 	nodes, _ := GetNetworkNodes(node.Network)
 	nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
-	//fmt.Printf("=====> NODES: %+v \n\n", nodes)
 	userNodes := GetStaticUserNodesByNetwork(models.NetworkID(node.Network))
-	//fmt.Printf("=====> USER NODES %+v \n\n", userNodes)
 	for _, userNodeI := range userNodes {
 		for _, peer := range nodes {
 			if peer.IsUserNode {

+ 2 - 4
logic/nodes.go

@@ -40,9 +40,7 @@ func getNodeFromCache(nodeID string) (node models.Node, ok bool) {
 }
 func getNodesFromCache() (nodes []models.Node) {
 	nodeCacheMutex.RLock()
-	for _, node := range nodesCacheMap {
-		nodes = append(nodes, node)
-	}
+	nodes = slices.Collect(maps.Values(nodesCacheMap))
 	nodeCacheMutex.RUnlock()
 	return
 }
@@ -141,7 +139,7 @@ func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node
 		defer nodeNetworkCacheMutex.Unlock()
 		return slices.Collect(maps.Values(networkNodes))
 	}
-	var nodes = []models.Node{}
+	var nodes = make([]models.Node, 0, len(allNodes))
 	for i := range allNodes {
 		node := allNodes[i]
 		if node.Network == network {

+ 81 - 3
logic/peers.go

@@ -59,6 +59,80 @@ var (
 	}
 )
 
+// GetHostPeerInfo - fetches required peer info per network
+func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
+	peerInfo := models.HostPeerInfo{
+		NetworkPeerIDs: make(map[models.NetworkID]models.PeerMap),
+	}
+	allNodes, err := GetAllNodes()
+	if err != nil {
+		return peerInfo, err
+	}
+	for _, nodeID := range host.Nodes {
+		nodeID := nodeID
+		node, err := GetNodeByID(nodeID)
+		if err != nil {
+			continue
+		}
+
+		if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE {
+			continue
+		}
+		networkPeersInfo := make(models.PeerMap)
+		defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+
+		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
+		for _, peer := range currentPeers {
+			peer := peer
+			if peer.ID.String() == node.ID.String() {
+				logger.Log(2, "peer update, skipping self")
+				// skip yourself
+				continue
+			}
+
+			peerHost, err := GetHost(peer.HostID.String())
+			if err != nil {
+				logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
+				continue
+			}
+
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
+			if peer.Action != models.NODE_DELETE &&
+				!peer.PendingDelete &&
+				peer.Connected &&
+				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
+				(defaultDevicePolicy.Enabled || allowedToComm) {
+
+				networkPeersInfo[peerHost.PublicKey.String()] = models.IDandAddr{
+					ID:         peer.ID.String(),
+					HostID:     peerHost.ID.String(),
+					Address:    peer.PrimaryAddress(),
+					Name:       peerHost.Name,
+					Network:    peer.Network,
+					ListenPort: peerHost.ListenPort,
+				}
+
+			}
+		}
+		var extPeerIDAndAddrs []models.IDandAddr
+		if node.IsIngressGateway {
+			_, extPeerIDAndAddrs, _, err = GetExtPeers(&node, &node)
+			if err == nil {
+				for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
+					networkPeersInfo[extPeerIdAndAddr.ID] = extPeerIdAndAddr
+				}
+			}
+		}
+		peerInfo.NetworkPeerIDs[models.NetworkID(node.Network)] = networkPeersInfo
+	}
+	return peerInfo, nil
+}
+
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
 func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.Node,
 	deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
@@ -295,15 +369,19 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				peerConfig.Endpoint.IP = peer.LocalAddress.IP
 				peerConfig.Endpoint.Port = peerHost.ListenPort
 			}
-			allowedips := GetAllowedIPs(&node, &peer, nil)
-			allowedToComm := IsPeerAllowed(node, peer, false)
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
 			if peer.Action != models.NODE_DELETE &&
 				!peer.PendingDelete &&
 				peer.Connected &&
 				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
 				(defaultDevicePolicy.Enabled || allowedToComm) &&
 				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
-				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
+				peerConfig.AllowedIPs = GetAllowedIPs(&node, &peer, nil) // only append allowed IPs if valid connection
 			}
 
 			var nodePeer wgtypes.PeerConfig

+ 0 - 0
logic/pro/failover


+ 1 - 0
logic/user_mgmt.go

@@ -60,6 +60,7 @@ var DeleteNetworkRoles = func(netID string) {}
 var CreateDefaultNetworkRolesAndGroups = func(netID models.NetworkID) {}
 var CreateDefaultUserPolicies = func(netID models.NetworkID) {}
 var GetUserGroupsInNetwork = func(netID models.NetworkID) (networkGrps map[models.UserGroupID]models.UserGroup) { return }
+var GetUserGroup = func(groupId models.UserGroupID) (userGrps models.UserGroup, err error) { return }
 var AddGlobalNetRolesToAdmins = func(u models.User) {}
 
 // GetRole - fetches role template by id

+ 4 - 0
models/mqtt.go

@@ -6,6 +6,10 @@ import (
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
+type HostPeerInfo struct {
+	NetworkPeerIDs map[NetworkID]PeerMap `json:"network_peers"`
+}
+
 // HostPeerUpdate - struct for host peer updates
 type HostPeerUpdate struct {
 	Host            Host                  `json:"host"`

+ 12 - 11
models/structs.go

@@ -32,17 +32,18 @@ type IngressGwUsers struct {
 
 // UserRemoteGws - struct to hold user's remote gws
 type UserRemoteGws struct {
-	GwID              string    `json:"remote_access_gw_id"`
-	GWName            string    `json:"gw_name"`
-	Network           string    `json:"network"`
-	Connected         bool      `json:"connected"`
-	IsInternetGateway bool      `json:"is_internet_gateway"`
-	GwClient          ExtClient `json:"gw_client"`
-	GwPeerPublicKey   string    `json:"gw_peer_public_key"`
-	GwListenPort      int       `json:"gw_listen_port"`
-	Metadata          string    `json:"metadata"`
-	AllowedEndpoints  []string  `json:"allowed_endpoints"`
-	NetworkAddresses  []string  `json:"network_addresses"`
+	GwID              string     `json:"remote_access_gw_id"`
+	GWName            string     `json:"gw_name"`
+	Network           string     `json:"network"`
+	Connected         bool       `json:"connected"`
+	IsInternetGateway bool       `json:"is_internet_gateway"`
+	GwClient          ExtClient  `json:"gw_client"`
+	GwPeerPublicKey   string     `json:"gw_peer_public_key"`
+	GwListenPort      int        `json:"gw_listen_port"`
+	Metadata          string     `json:"metadata"`
+	AllowedEndpoints  []string   `json:"allowed_endpoints"`
+	NetworkAddresses  []string   `json:"network_addresses"`
+	Status            NodeStatus `json:"status"`
 }
 
 // UserRAGs - struct for user access gws

+ 1 - 0
models/user_mgmt.go

@@ -159,6 +159,7 @@ type User struct {
 type ReturnUserWithRolesAndGroups struct {
 	ReturnUser
 	PlatformRole UserRolePermissionTemplate `json:"platform_role"`
+	UserGroups   map[UserGroupID]UserGroup  `json:"user_group_ids"`
 }
 
 // ReturnUser - return user struct

+ 137 - 4
pro/controllers/failover.go

@@ -19,7 +19,7 @@ import (
 
 // FailOverHandlers - handlers for FailOver
 func FailOverHandlers(r *mux.Router) {
-	r.HandleFunc("/api/v1/node/{nodeid}/failover", http.HandlerFunc(getfailOver)).
+	r.HandleFunc("/api/v1/node/{nodeid}/failover", controller.Authorize(true, false, "host", http.HandlerFunc(getfailOver))).
 		Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover", logic.SecurityCheck(true, http.HandlerFunc(createfailOver))).
 		Methods(http.MethodPost)
@@ -29,6 +29,8 @@ func FailOverHandlers(r *mux.Router) {
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/node/{nodeid}/failover_me", controller.Authorize(true, false, "host", http.HandlerFunc(failOverME))).
 		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/failover_check", controller.Authorize(true, false, "host", http.HandlerFunc(checkfailOverCtx))).
+		Methods(http.MethodGet)
 }
 
 // @Summary     Get failover node
@@ -44,7 +46,6 @@ func getfailOver(w http.ResponseWriter, r *http.Request) {
 	// confirm host exists
 	node, err := logic.GetNodeByID(nodeid)
 	if err != nil {
-		slog.Error("failed to get node:", "node", nodeid, "error", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
@@ -140,6 +141,7 @@ func deletefailOver(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	proLogic.RemoveFailOverFromCache(node.Network)
 	go func() {
 		proLogic.ResetFailOver(&node)
 		mq.PublishPeerUpdate(false)
@@ -265,10 +267,9 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 		)
 		return
 	}
-
 	err = proLogic.SetFailOverCtx(failOverNode, node, peerNode)
 	if err != nil {
-		slog.Error("failed to create failover", "id", node.ID.String(),
+		slog.Debug("failed to create failover", "id", node.ID.String(),
 			"network", node.Network, "error", err)
 		logic.ReturnErrorResponse(
 			w,
@@ -293,3 +294,135 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	logic.ReturnSuccessResponse(w, r, "relayed successfully")
 }
+
+// @Summary     checkfailOverCtx
+// @Router      /api/v1/node/{nodeid}/failover_check [get]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.FailOverMeReq true "Failover request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	failOverNode, exists := proLogic.FailOverExists(node.Network)
+	if !exists {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, failover node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+	var failOverReq models.FailOverMeReq
+	err = json.NewDecoder(r.Body).Decode(&failOverReq)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	peerNode, err := logic.GetNodeByID(failOverReq.NodeID)
+	if err != nil {
+		slog.Error("peer not found: ", "nodeid", failOverReq.NodeID, "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer not found"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if node.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsFailOver {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as failover"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelayed && node.RelayedBy == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelay && peerNode.RelayedBy == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node acting as internet gw for the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+
+	err = proLogic.CheckFailOverCtx(failOverNode, node, peerNode)
+	if err != nil {
+		slog.Error("failover ctx cannot be set ", "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(fmt.Errorf("failover ctx cannot be set: %v", err), "internal"),
+		)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "failover can be set")
+}

+ 10 - 0
pro/controllers/users.go

@@ -1102,6 +1102,10 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 				slog.Error("failed to get node network", "error", err)
 				continue
 			}
+			nodesWithStatus := logic.AddStatusToNodes([]models.Node{node})
+			if len(nodesWithStatus) > 0 {
+				node = nodesWithStatus[0]
+			}
 
 			gws := userGws[node.Network]
 			extClient.AllowedIPs = logic.GetExtclientAllowedIPs(extClient)
@@ -1117,6 +1121,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 				Metadata:          node.Metadata,
 				AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
 				NetworkAddresses:  []string{network.AddressRange, network.AddressRange6},
+				Status:            node.Status,
 			})
 			userGws[node.Network] = gws
 			delete(userGwNodes, node.ID.String())
@@ -1138,6 +1143,10 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 		if err != nil {
 			continue
 		}
+		nodesWithStatus := logic.AddStatusToNodes([]models.Node{node})
+		if len(nodesWithStatus) > 0 {
+			node = nodesWithStatus[0]
+		}
 		network, err := logic.GetNetwork(node.Network)
 		if err != nil {
 			slog.Error("failed to get node network", "error", err)
@@ -1154,6 +1163,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 			Metadata:          node.Metadata,
 			AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
 			NetworkAddresses:  []string{network.AddressRange, network.AddressRange6},
+			Status:            node.Status,
 		})
 		userGws[node.Network] = gws
 	}

+ 2 - 0
pro/initialize.go

@@ -90,6 +90,7 @@ func InitPro() {
 			slog.Error("no OAuth provider found or not configured, continuing without OAuth")
 		}
 		proLogic.LoadNodeMetricsToCache()
+		proLogic.InitFailOverCache()
 	})
 	logic.ResetFailOver = proLogic.ResetFailOver
 	logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer
@@ -132,6 +133,7 @@ func InitPro() {
 	logic.IntialiseGroups = proLogic.UserGroupsInit
 	logic.AddGlobalNetRolesToAdmins = proLogic.AddGlobalNetRolesToAdmins
 	logic.GetUserGroupsInNetwork = proLogic.GetUserGroupsInNetwork
+	logic.GetUserGroup = proLogic.GetUserGroup
 	logic.GetNodeStatus = proLogic.GetNodeStatus
 }
 

+ 67 - 12
pro/logic/failover.go

@@ -13,7 +13,49 @@ import (
 )
 
 var failOverCtxMutex = &sync.RWMutex{}
+var failOverCacheMutex = &sync.RWMutex{}
+var failOverCache = make(map[models.NetworkID]string)
 
+func InitFailOverCache() {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	networks, err := logic.GetNetworks()
+	if err != nil {
+		return
+	}
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		return
+	}
+
+	for _, network := range networks {
+		networkNodes := logic.GetNetworkNodesMemory(allNodes, network.NetID)
+		for _, node := range networkNodes {
+			if node.IsFailOver {
+				failOverCache[models.NetworkID(network.NetID)] = node.ID.String()
+				break
+			}
+		}
+	}
+}
+
+func CheckFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
+	failOverCtxMutex.RLock()
+	defer failOverCtxMutex.RUnlock()
+	if peerNode.FailOverPeers == nil {
+		return nil
+	}
+	if victimNode.FailOverPeers == nil {
+		return nil
+	}
+	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
+	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
+	if peerHasFailovered && victimHasFailovered &&
+		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID {
+		return errors.New("failover ctx is already set")
+	}
+	return nil
+}
 func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	failOverCtxMutex.Lock()
 	defer failOverCtxMutex.Unlock()
@@ -23,13 +65,16 @@ func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error {
 	if victimNode.FailOverPeers == nil {
 		victimNode.FailOverPeers = make(map[string]struct{})
 	}
+	_, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()]
+	_, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()]
+	if peerHasFailovered && victimHasFailovered &&
+		victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID {
+		return errors.New("failover ctx is already set")
+	}
 	peerNode.FailOverPeers[victimNode.ID.String()] = struct{}{}
 	victimNode.FailOverPeers[peerNode.ID.String()] = struct{}{}
 	victimNode.FailedOverBy = failOverNode.ID
 	peerNode.FailedOverBy = failOverNode.ID
-	if err := logic.UpsertNode(&failOverNode); err != nil {
-		return err
-	}
 	if err := logic.UpsertNode(&victimNode); err != nil {
 		return err
 	}
@@ -50,17 +95,26 @@ func GetFailOverNode(network string, allNodes []models.Node) (models.Node, error
 	return models.Node{}, errors.New("auto relay not found")
 }
 
+func RemoveFailOverFromCache(network string) {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	delete(failOverCache, models.NetworkID(network))
+}
+
+func SetFailOverInCache(node models.Node) {
+	failOverCacheMutex.Lock()
+	defer failOverCacheMutex.Unlock()
+	failOverCache[models.NetworkID(node.Network)] = node.ID.String()
+}
+
 // FailOverExists - checks if failOver exists already in the network
 func FailOverExists(network string) (failOverNode models.Node, exists bool) {
-	nodes, err := logic.GetNetworkNodes(network)
-	if err != nil {
-		return
-	}
-	for _, node := range nodes {
-		if node.IsFailOver {
-			exists = true
-			failOverNode = node
-			return
+	failOverCacheMutex.RLock()
+	defer failOverCacheMutex.RUnlock()
+	if nodeID, ok := failOverCache[models.NetworkID(network)]; ok {
+		failOverNode, err := logic.GetNodeByID(nodeID)
+		if err == nil {
+			return failOverNode, true
 		}
 	}
 	return
@@ -185,5 +239,6 @@ func CreateFailOver(node models.Node) error {
 		slog.Error("failed to upsert node", "node", node.ID.String(), "error", err)
 		return err
 	}
+	SetFailOverInCache(node)
 	return nil
 }

+ 6 - 4
pro/logic/user_mgmt.go

@@ -216,8 +216,9 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) {
 
 	// create default network groups
 	var NetworkAdminGroup = models.UserGroup{
-		ID:   models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkAdmin)),
-		Name: fmt.Sprintf("%s Admin Group", netID),
+		ID:      models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkAdmin)),
+		Name:    fmt.Sprintf("%s Admin Group", netID),
+		Default: true,
 		NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{
 			netID: {
 				models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkAdmin)): {},
@@ -226,8 +227,9 @@ func CreateDefaultNetworkRolesAndGroups(netID models.NetworkID) {
 		MetaData: fmt.Sprintf("can manage your network `%s` configuration including adding and removing devices.", netID),
 	}
 	var NetworkUserGroup = models.UserGroup{
-		ID:   models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkUser)),
-		Name: fmt.Sprintf("%s User Group", netID),
+		ID:      models.UserGroupID(fmt.Sprintf("%s-%s-grp", netID, models.NetworkUser)),
+		Name:    fmt.Sprintf("%s User Group", netID),
+		Default: true,
 		NetworkRoles: map[models.NetworkID]map[models.UserRoleID]struct{}{
 			netID: {
 				models.UserRoleID(fmt.Sprintf("%s-%s", netID, models.NetworkUser)): {},