Переглянути джерело

NM-122: Auto Relay, auto assignment of Gw (#3697)

* add auto realy handlers and logic funcs

* add pro func connectors

* Add auto relayed peer ips on peer update, set auto relay on gw creation

* add network id to signal, add autorelay nodes to peerudpate

* add autorelay peer update logic

* add nodes to peer update

* revert node model change

* reset auto relayed peers on the relay node on reset, add auto relay nodes to pull

* add logic api to update auto relay node

* add autoassigngw field to node, add logic to swith relay node in relayme udpate api

* add gw nodes to pull

* intilaise gw map

* HA relay functionality

* add autoassign gw option to enrollment key

* publish intant action to auto assign gw

* fix static checks

* unset relay if auto assign removed

* add host node model to auto relay info

* add host node model to auto relay info

* only use hostNode model for gws info

* handle autoassigned gw peer in the update

* handle autoassigned gw peer in the update

* handle peer updates for autoassigned gw peer

* unset auto assigned peer if relayed or failedovered
Abhishek K 4 днів тому
батько
коміт
74fef9fbc6

+ 4 - 0
auth/host_session.go

@@ -309,6 +309,10 @@ func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username
 			}
 
 			newNode, err := logic.UpdateHostNetwork(h, netID, true)
+			if servercfg.IsPro && key.AutoAssignGateway {
+				newNode.AutoAssignGateway = true
+				logic.UpsertNode(newNode)
+			}
 			if err == nil || strings.Contains(err.Error(), "host already part of network") {
 				if len(key.Groups) > 0 {
 					newNode.Tags = make(map[models.TagID]struct{})

+ 1 - 0
controllers/enrollmentkeys.go

@@ -181,6 +181,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 		relayId,
 		false,
 		enrollmentKeyBody.AutoEgress,
+		enrollmentKeyBody.AutoAssignGateway,
 	)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())

+ 13 - 0
controllers/gateway.go

@@ -89,6 +89,9 @@ func createGateway(w http.ResponseWriter, r *http.Request) {
 			if relayedNode.FailedOverBy != uuid.Nil {
 				go logic.ResetFailedOverPeer(&relayedNode)
 			}
+			if relayedNode.AutoRelayedBy != uuid.Nil {
+				go logic.ResetAutoRelayedPeer(&relayedNode)
+			}
 
 		}
 	}
@@ -101,6 +104,12 @@ func createGateway(w http.ResponseWriter, r *http.Request) {
 					mq.PublishPeerUpdate(false)
 				}()
 			}
+
+			go func() {
+				logic.ResetAutoRelayedPeer(&node)
+				mq.PublishPeerUpdate(false)
+			}()
+
 		}
 		if node.IsGw && node.IngressDNS == "" {
 			node.IngressDNS = "1.1.1.1"
@@ -190,6 +199,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
 	}
 	logic.UnsetInternetGw(&node)
 	node.IsGw = false
+	if node.IsAutoRelay {
+		logic.ResetAutoRelay(&node)
+	}
+	node.IsAutoRelay = false
 	logic.UpsertNode(&node)
 	logger.Log(1, r.Header.Get("user"), "deleted gw", nodeid, "on network", netid)
 

+ 6 - 0
controllers/hosts.go

@@ -212,6 +212,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
 		}
 		if r.URL.Query().Get("reset_failovered") == "true" {
 			logic.ResetFailedOverPeer(&node)
+			logic.ResetAutoRelayedPeer(&node)
 			sendPeerUpdate = true
 		}
 	}
@@ -250,6 +251,8 @@ func pull(w http.ResponseWriter, r *http.Request) {
 		EndpointDetection: logic.IsEndpointDetectionEnabled(),
 		DnsNameservers:    hPU.DnsNameservers,
 		ReplacePeers:      hPU.ReplacePeers,
+		AutoRelayNodes:    hPU.AutoRelayNodes,
+		GwNodes:           hPU.GwNodes,
 	}
 
 	logger.Log(1, hostID, host.Name, "completed a pull")
@@ -1231,6 +1234,9 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) {
 		})
 		return
 	}
+	if key.AutoAssignGateway {
+		newNode.AutoAssignGateway = true
+	}
 	if len(key.Groups) > 0 {
 		newNode.Tags = make(map[models.TagID]struct{})
 		for _, tagI := range key.Groups {

+ 5 - 0
controllers/inet_gws.go

@@ -72,6 +72,11 @@ func createInternetGw(w http.ResponseWriter, r *http.Request) {
 				mq.PublishPeerUpdate(false)
 			}()
 		}
+		go func() {
+			logic.ResetAutoRelayedPeer(&node)
+			mq.PublishPeerUpdate(false)
+		}()
+
 	}
 	if node.IsGw && node.IngressDNS == "" {
 		node.IngressDNS = "1.1.1.1"

+ 36 - 0
controllers/node.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"strings"
 
+	"github.com/google/uuid"
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
@@ -639,6 +640,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		)
 		return
 	}
+	if currentNode.IsAutoRelay && !newNode.IsAutoRelay {
+		logic.ResetAutoRelay(newNode)
+	}
+
 	if newNode.IsInternetGateway && len(newNode.InetNodeReq.InetNodeClientIDs) > 0 {
 		err = logic.ValidateInetGwReq(*newNode, newNode.InetNodeReq, newNode.IsInternetGateway && currentNode.IsInternetGateway)
 		if err != nil {
@@ -648,6 +653,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		newNode.RelayedNodes = append(newNode.RelayedNodes, newNode.InetNodeReq.InetNodeClientIDs...)
 		newNode.RelayedNodes = logic.UniqueStrings(newNode.RelayedNodes)
 	}
+
 	relayUpdate := logic.RelayUpdates(&currentNode, newNode)
 	if relayUpdate && newNode.IsRelay {
 		err = logic.ValidateRelay(models.RelayRequest{
@@ -692,6 +698,33 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 	if !newNode.IsInternetGateway {
 		logic.UnsetInternetGw(newNode)
 	}
+	if currentNode.AutoAssignGateway && !newNode.AutoAssignGateway {
+		// if relayed remove it
+		if newNode.IsRelayed {
+			relayNode, err := logic.GetNodeByID(newNode.RelayedBy)
+			if err == nil {
+				logic.RemoveAllFromSlice(relayNode.RelayedNodes, newNode.ID.String())
+				logic.UpsertNode(&relayNode)
+			}
+			newNode.IsRelayed = false
+			newNode.RelayedBy = ""
+		}
+	}
+	if (currentNode.IsRelayed || currentNode.FailedOverBy != uuid.Nil) && newNode.AutoAssignGateway {
+		// if relayed remove it
+		if currentNode.IsRelayed {
+			relayNode, err := logic.GetNodeByID(currentNode.RelayedBy)
+			if err == nil {
+				logic.RemoveAllFromSlice(relayNode.RelayedNodes, currentNode.ID.String())
+				logic.UpsertNode(&relayNode)
+			}
+			newNode.IsRelayed = false
+			newNode.RelayedBy = ""
+		}
+		if currentNode.FailedOverBy != uuid.Nil {
+			logic.ResetAutoRelayedPeer(&currentNode)
+		}
+	}
 	logic.UpsertNode(newNode)
 	logic.GetNodeStatus(newNode, false)
 
@@ -733,6 +766,9 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		// 	mq.HostUpdate(&models.HostUpdate{Host: *host, Action: models.SignalPull})
 		// }
 		mq.PublishPeerUpdate(false)
+		if newNode.AutoAssignGateway {
+			mq.HostUpdate(&models.HostUpdate{Action: models.CheckAutoAssignGw, Host: *host, Node: *newNode})
+		}
 		if servercfg.IsDNSMode() {
 			logic.SetDNS()
 		}

+ 15 - 12
logic/enrollmentkey.go

@@ -38,23 +38,26 @@ var (
 )
 
 // CreateEnrollmentKey - creates a new enrollment key in db
-func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey, autoEgress bool) (*models.EnrollmentKey, error) {
+func CreateEnrollmentKey(uses int, expiration time.Time, networks,
+	tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID,
+	defaultKey, autoEgress, autoAssignGw bool) (*models.EnrollmentKey, error) {
 	newKeyID, err := getUniqueEnrollmentID()
 	if err != nil {
 		return nil, err
 	}
 	k := &models.EnrollmentKey{
-		Value:         newKeyID,
-		Expiration:    time.Time{},
-		UsesRemaining: 0,
-		Unlimited:     unlimited,
-		Networks:      []string{},
-		Tags:          []string{},
-		Type:          models.Undefined,
-		Relay:         relay,
-		Groups:        groups,
-		Default:       defaultKey,
-		AutoEgress:    autoEgress,
+		Value:             newKeyID,
+		Expiration:        time.Time{},
+		UsesRemaining:     0,
+		Unlimited:         unlimited,
+		Networks:          []string{},
+		Tags:              []string{},
+		Type:              models.Undefined,
+		Relay:             relay,
+		Groups:            groups,
+		Default:           defaultKey,
+		AutoEgress:        autoEgress,
+		AutoAssignGateway: autoAssignGw,
 	}
 	if uses > 0 {
 		k.UsesRemaining = uses

+ 16 - 15
logic/enrollmentkey_test.go

@@ -1,11 +1,12 @@
 package logic
 
 import (
-	"github.com/gravitl/netmaker/db"
-	"github.com/gravitl/netmaker/schema"
 	"testing"
 	"time"
 
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
+
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
@@ -19,35 +20,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	t.Run("Can_Not_Create_Key", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false)
 		assert.Nil(t, newKey)
 		assert.NotNil(t, err)
 		assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
 	})
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
+		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false)
 		assert.Nil(t, err)
 		assert.Equal(t, 1, newKey.UsesRemaining)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Time", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false, false)
+		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Networks) == 2)
 	})
 	t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false, false, false)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Tags) == 2)
@@ -70,7 +71,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false)
 	t.Run("Can_Delete_Key", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		err := DeleteEnrollmentKey(newKey.Value, false)
@@ -94,7 +95,7 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
 
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
+	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false)
 	t.Run("Check_initial_uses", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		assert.Equal(t, newKey.UsesRemaining, 1)
@@ -121,9 +122,9 @@ func TestUsability_EnrollmentKey(t *testing.T) {
 
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
-	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false, false)
-	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false)
+	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false)
+	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false, false, false)
+	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false, false)
 	t.Run("Check if valid use key can be used", func(t *testing.T) {
 		assert.Equal(t, key1.UsesRemaining, 1)
 		ok := TryToUseEnrollmentKey(key1)
@@ -162,7 +163,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
 
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false)
 	const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
 	const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
@@ -198,7 +199,7 @@ func TestDeTokenize_EnrollmentKeys(t *testing.T) {
 
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false)
 	const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
 

+ 7 - 0
logic/gateway.go

@@ -198,6 +198,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
 	}
 	node.IsIngressGateway = true
 	node.IsGw = true
+	SetAutoRelay(&node)
 	node.IsInternetGateway = ingress.IsInternetGateway
 	node.IngressGatewayRange = network.AddressRange
 	node.IngressGatewayRange6 = network.AddressRange6
@@ -217,6 +218,9 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
 		if _, exists := FailOverExists(node.Network); exists {
 			ResetFailedOverPeer(&node)
 		}
+
+		ResetAutoRelayedPeer(&node)
+
 	}
 	node.SetLastModified()
 	node.Metadata = ingress.Metadata
@@ -370,6 +374,9 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool
 		if clientNode.FailedOverBy != uuid.Nil {
 			ResetFailedOverPeer(&clientNode)
 		}
+		if clientNode.AutoRelayedBy != uuid.Nil {
+			ResetAutoRelayedPeer(&clientNode)
+		}
 
 		if clientNode.IsRelayed && clientNode.RelayedBy != inetNode.ID.String() {
 			return fmt.Errorf("node %s is being relayed", clientHost.Name)

+ 4 - 0
logic/hosts.go

@@ -340,6 +340,9 @@ func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool)
 			if node.FailedOverBy != uuid.Nil {
 				ResetFailedOverPeer(&node)
 			}
+			if node.AutoRelayedBy != uuid.Nil {
+				ResetAutoRelayedPeer(&node)
+			}
 		}
 	}
 
@@ -396,6 +399,7 @@ func UpdateHostNode(h *models.Host, newNode *models.Node) (publishDeletedNodeUpd
 		publishPeerUpdate = true
 		// reset failover data for this node
 		ResetFailedOverPeer(newNode)
+		ResetAutoRelayedPeer(newNode)
 	}
 	return
 }

+ 1 - 0
logic/networks.go

@@ -307,6 +307,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 		uuid.Nil,
 		true,
 		false,
+		false,
 	)
 
 	return network, nil

+ 3 - 0
logic/nodes.go

@@ -282,6 +282,9 @@ func DeleteNode(node *models.Node, purge bool) error {
 	if node.FailedOverBy != uuid.Nil {
 		ResetFailedOverPeer(node)
 	}
+	if node.AutoRelayedBy != uuid.Nil {
+		ResetAutoRelayedPeer(node)
+	}
 	if node.IsRelay {
 		// unset all the relayed nodes
 		SetRelayedNodes(false, node.ID.String(), node.RelayedNodes)

+ 77 - 20
logic/peers.go

@@ -44,6 +44,25 @@ var (
 	}
 )
 
+var (
+	// ResetAutoRelay - function to reset autorelayed peers on this node
+	ResetAutoRelay = func(autoRelayNode *models.Node) error {
+		return nil
+	}
+	// ResetAutoRelayedPeer - removes relayed peers for node
+	ResetAutoRelayedPeer = func(failedOverNode *models.Node) error {
+		return nil
+	}
+	// GetAutoRelayPeerIps - gets autorelay peerips
+	GetAutoRelayPeerIps = func(peer, node *models.Node) []net.IPNet {
+		return []net.IPNet{}
+	}
+	// SetAutoRelay - sets autorelay flag on the node
+	SetAutoRelay = func(node *models.Node) {
+		node.IsAutoRelay = false
+	}
+)
+
 // GetHostPeerInfo - fetches required peer info per network
 func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
 	peerInfo := models.HostPeerInfo{
@@ -143,6 +162,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		HostNetworkInfo: models.HostInfoMap{},
 		ServerConfig:    GetServerInfo(),
 		DnsNameservers:  GetNameserversForHost(host),
+		AutoRelayNodes:  make(map[models.NetworkID][]models.Node),
+		GwNodes:         make(map[models.NetworkID][]models.Node),
 	}
 	if host.DNS == "no" {
 		hostPeerUpdate.ManageDNS = false
@@ -180,6 +201,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 			(!node.LastCheckIn.IsZero() && time.Since(node.LastCheckIn) > time.Hour) {
 			continue
 		}
+		hostPeerUpdate.Nodes = append(hostPeerUpdate.Nodes, node)
 		acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
 		eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
 		GetNodeEgressInfo(&node, eli, acls)
@@ -252,9 +274,11 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				node.Mutex.Lock()
 			}
 			_, isFailOverPeer := node.FailOverPeers[peer.ID.String()]
+			_, isAutoRelayPeer := node.AutoRelayedPeers[peer.ID.String()]
 			if node.Mutex != nil {
 				node.Mutex.Unlock()
 			}
+
 			if peer.EgressDetails.IsEgressGateway {
 				peerKey := peerHost.PublicKey.String()
 				if isFailOverPeer && peer.FailedOverBy.String() != node.ID.String() {
@@ -267,6 +291,16 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 						}
 					}
 				}
+				if isAutoRelayPeer && peer.AutoRelayedBy.String() != node.ID.String() {
+					// get relay host
+					autoRelayNode, err := GetNodeByID(peer.AutoRelayedBy.String())
+					if err == nil {
+						relayHost, err := GetHost(autoRelayNode.HostID.String())
+						if err == nil {
+							peerKey = relayHost.PublicKey.String()
+						}
+					}
+				}
 				if peer.IsRelayed && (peer.RelayedBy != node.ID.String()) {
 					// get relay host
 					relayNode, err := GetNodeByID(peer.RelayedBy)
@@ -292,20 +326,42 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 			if peer.IsIngressGateway {
 				hostPeerUpdate.EgressRoutes = append(hostPeerUpdate.EgressRoutes, getExtpeersExtraRoutes(node)...)
 			}
-
-			if (node.IsRelayed && node.RelayedBy != peer.ID.String()) ||
-				(peer.IsRelayed && peer.RelayedBy != node.ID.String()) || isFailOverPeer {
-				// if node is relayed and peer is not the relay, set remove to true
-				if _, ok := peerIndexMap[peerHost.PublicKey.String()]; ok {
-					continue
+			var allowedToComm bool
+			if defaultDevicePolicy.Enabled {
+				allowedToComm = true
+			} else {
+				allowedToComm = IsPeerAllowed(node, peer, false)
+			}
+			if allowedToComm {
+				if peer.IsAutoRelay {
+					hostPeerUpdate.AutoRelayNodes[models.NetworkID(peer.Network)] = append(hostPeerUpdate.AutoRelayNodes[models.NetworkID(peer.Network)],
+						peer)
+				}
+				if node.AutoAssignGateway && peer.IsGw {
+					hostPeerUpdate.GwNodes[models.NetworkID(peer.Network)] = append(hostPeerUpdate.GwNodes[models.NetworkID(peer.Network)],
+						peer)
 				}
-				peerConfig.Remove = true
-				hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
-				peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
-				continue
 			}
-			if node.IsRelayed && node.RelayedBy == peer.ID.String() {
-				hostPeerUpdate = SetDefaultGwForRelayedUpdate(node, peer, hostPeerUpdate)
+			shouldCheckRelayed := true
+			if (node.AutoAssignGateway && peer.IsGw && node.RelayedBy != peer.ID.String()) ||
+				(peer.AutoAssignGateway && node.IsGw && peer.RelayedBy != node.ID.String()) {
+				shouldCheckRelayed = false
+			}
+			if shouldCheckRelayed {
+				if (node.IsRelayed && node.RelayedBy != peer.ID.String()) ||
+					(peer.IsRelayed && peer.RelayedBy != node.ID.String()) || isFailOverPeer || isAutoRelayPeer {
+					// if node is relayed and peer is not the relay, set remove to true
+					if _, ok := peerIndexMap[peerHost.PublicKey.String()]; ok {
+						continue
+					}
+					peerConfig.Remove = true
+					hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
+					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
+					continue
+				}
+				if node.IsRelayed && node.RelayedBy == peer.ID.String() {
+					hostPeerUpdate = SetDefaultGwForRelayedUpdate(node, peer, hostPeerUpdate)
+				}
 			}
 
 			uselocal := false
@@ -342,13 +398,17 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 					peerEndpoint = peerHost.EndpointIPv6
 				}
 			}
-			if node.IsRelay && peer.RelayedBy == node.ID.String() && peer.InternetGwID == "" && !peer.IsStatic {
+			if (node.IsRelay && peer.RelayedBy == node.ID.String() && peer.InternetGwID == "") ||
+				(peer.AutoAssignGateway && node.IsGw) && !peer.IsStatic {
 				// don't set endpoint on relayed peer
 				peerEndpoint = nil
 			}
 			if isFailOverPeer && peer.FailedOverBy == node.ID && !peer.IsStatic {
 				peerEndpoint = nil
 			}
+			if isAutoRelayPeer && peer.AutoRelayedBy == node.ID && !peer.IsStatic {
+				peerEndpoint = nil
+			}
 
 			peerConfig.Endpoint = &net.UDPAddr{
 				IP:   peerEndpoint,
@@ -358,12 +418,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 			if uselocal {
 				peerConfig.Endpoint.Port = peerHost.ListenPort
 			}
-			var allowedToComm bool
-			if defaultDevicePolicy.Enabled {
-				allowedToComm = true
-			} else {
-				allowedToComm = IsPeerAllowed(node, peer, false)
-			}
+
 			if peer.Action != models.NODE_DELETE &&
 				!peer.PendingDelete &&
 				peer.Connected &&
@@ -372,7 +427,6 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				(deletedNode == nil || (peer.ID.String() != deletedNode.ID.String())) {
 				peerConfig.AllowedIPs = GetAllowedIPs(&node, &peer, nil) // only append allowed IPs if valid connection
 			}
-
 			var nodePeer wgtypes.PeerConfig
 			if _, ok := peerIndexMap[peerHost.PublicKey.String()]; !ok {
 				hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
@@ -711,6 +765,9 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
 	if peer.IsFailOver {
 		allowedips = append(allowedips, GetFailOverPeerIps(peer, node)...)
 	}
+	if peer.IsAutoRelay {
+		allowedips = append(allowedips, GetAutoRelayPeerIps(peer, node)...)
+	}
 	return allowedips
 }
 

+ 19 - 2
logic/relay.go

@@ -82,6 +82,13 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N
 		}
 		returnnodes = append(returnnodes, node)
 	}
+	relayNode, _ := GetNodeByID(relay)
+	if setRelayed {
+		relayNode.RelayedNodes = relayed
+	} else {
+		relayNode.RelayedNodes = []string{}
+	}
+	UpsertNode(&relayNode)
 	return returnnodes
 }
 
@@ -135,12 +142,15 @@ func ValidateRelay(relay models.RelayRequest, update bool) error {
 		if relayedNode.FailedOverBy != uuid.Nil {
 			ResetFailedOverPeer(&relayedNode)
 		}
+		if relayedNode.AutoRelayedBy != uuid.Nil {
+			ResetAutoRelayedPeer(&relayedNode)
+		}
 	}
 	return err
 }
 
 // UpdateRelayNodes - updates relay nodes
-func updateRelayNodes(relay string, oldNodes []string, newNodes []string) []models.Node {
+func UpdateRelayNodes(relay string, oldNodes []string, newNodes []string) []models.Node {
 	_ = SetRelayedNodes(false, relay, oldNodes)
 	return SetRelayedNodes(true, relay, newNodes)
 }
@@ -163,11 +173,12 @@ func RelayUpdates(currentNode, newNode *models.Node) bool {
 
 // UpdateRelayed - updates a relay's relayed nodes, and sends updates to the relayed nodes over MQ
 func UpdateRelayed(currentNode, newNode *models.Node) {
-	updatenodes := updateRelayNodes(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
+	updatenodes := UpdateRelayNodes(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
 	if len(updatenodes) > 0 {
 		for _, relayedNode := range updatenodes {
 			node := relayedNode
 			ResetFailedOverPeer(&node)
+			ResetAutoRelayedPeer(&node)
 		}
 	}
 }
@@ -201,6 +212,9 @@ func RelayedAllowedIPs(peer, node *models.Node) []net.IPNet {
 		if err != nil {
 			continue
 		}
+		if relayedNode.AutoAssignGateway && node.IsGw {
+			continue
+		}
 		GetNodeEgressInfo(&relayedNode, eli, acls)
 		allowed := getRelayedAddresses(relayedNodeID)
 		if relayedNode.EgressDetails.IsEgressGateway {
@@ -232,6 +246,9 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe
 		if peer.ID == relayed.ID || peer.ID == relay.ID {
 			continue
 		}
+		if relayed.AutoAssignGateway && peer.IsGw {
+			continue
+		}
 		if !IsPeerAllowed(*relayed, peer, true) {
 			continue
 		}

+ 13 - 0
logic/util.go

@@ -166,6 +166,19 @@ func RemoveStringSlice(slice []string, i int) []string {
 	return append(slice[:i], slice[i+1:]...)
 }
 
+// RemoveAllFromSlice removes every occurrence of val from s (stable order).
+func RemoveAllFromSlice[T comparable](s []T, val T) []T {
+	// Reuse the underlying array: write filtered items back into s[:0].
+	out := s[:0]
+	for _, v := range s {
+		if v != val {
+			out = append(out, v)
+		}
+	}
+	// out now contains only the kept items; capacity unchanged, len shrunk.
+	return out
+}
+
 // IsSlicesEqual tells whether a and b contain the same elements.
 // A nil argument is equivalent to an empty slice.
 func IsSlicesEqual(a, b []string) bool {

+ 15 - 0
migrate/migrate.go

@@ -309,6 +309,7 @@ func updateEnrollmentKeys() {
 			uuid.Nil,
 			true,
 			false,
+			false,
 		)
 
 	}
@@ -383,6 +384,20 @@ func updateNodes() {
 			node.Tags = make(map[models.TagID]struct{})
 			logic.UpsertNode(&node)
 		}
+		// deprecate failover  and initialise auto relay fields
+		if node.IsFailOver {
+			node.IsFailOver = false
+			node.FailOverPeers = make(map[string]struct{})
+			node.FailedOverBy = uuid.Nil
+			node.AutoRelayedPeers = make(map[string]struct{})
+			logic.UpsertNode(&node)
+		}
+		if node.FailedOverBy != uuid.Nil || len(node.FailOverPeers) > 0 {
+			node.FailOverPeers = make(map[string]struct{})
+			node.FailedOverBy = uuid.Nil
+			node.AutoRelayedPeers = make(map[string]struct{})
+			logic.UpsertNode(&node)
+		}
 		if node.IsIngressGateway {
 			host, err := logic.GetHost(node.HostID.String())
 			if err == nil {

+ 10 - 0
models/api_node.go

@@ -33,6 +33,10 @@ type ApiNode struct {
 	IsRelayed                     bool                `json:"isrelayed"`
 	IsRelay                       bool                `json:"isrelay"`
 	IsGw                          bool                `json:"is_gw"`
+	IsAutoRelay                   bool                `json:"is_auto_relay"`
+	AutoRelayedPeers              map[string]struct{} `json:"auto_relayed_peers"`
+	AutoAssignGateway             bool                `json:"auto_assign_gw"`
+	AutoRelayedBy                 uuid.UUID           `json:"auto_relayed_by"`
 	RelayedBy                     string              `json:"relayedby" bson:"relayedby" yaml:"relayedby"`
 	RelayedNodes                  []string            `json:"relaynodes" yaml:"relayedNodes"`
 	IsEgressGateway               bool                `json:"isegressgateway"`
@@ -134,10 +138,12 @@ func (a *ApiNode) ConvertToServerNode(currentNode *Node) *Node {
 	}
 	convertedNode.Tags = a.Tags
 	convertedNode.IsGw = a.IsGw
+	convertedNode.IsAutoRelay = a.IsAutoRelay
 	if convertedNode.IsGw {
 		convertedNode.IsRelay = true
 		convertedNode.IsIngressGateway = true
 	}
+	convertedNode.AutoAssignGateway = a.AutoAssignGateway
 	return &convertedNode
 }
 
@@ -189,6 +195,10 @@ func (nm *Node) ConvertToAPINode() *ApiNode {
 	apiNode.IsGw = nm.IsGw
 	apiNode.RelayedBy = nm.RelayedBy
 	apiNode.RelayedNodes = nm.RelayedNodes
+	apiNode.IsAutoRelay = nm.IsAutoRelay
+	apiNode.AutoRelayedBy = nm.AutoRelayedBy
+	apiNode.AutoRelayedPeers = nm.AutoRelayedPeers
+	apiNode.AutoAssignGateway = nm.AutoAssignGateway
 	apiNode.IsIngressGateway = nm.IsIngressGateway
 	apiNode.IngressDns = nm.IngressDNS
 	apiNode.IngressPersistentKeepalive = nm.IngressPersistentKeepalive

+ 23 - 21
models/enrollment_key.go

@@ -43,31 +43,33 @@ const EnrollmentKeyLength = 32
 
 // EnrollmentKey - the key used to register hosts and join them to specific networks
 type EnrollmentKey struct {
-	Expiration    time.Time `json:"expiration"`
-	UsesRemaining int       `json:"uses_remaining"`
-	Value         string    `json:"value"`
-	Networks      []string  `json:"networks"`
-	Unlimited     bool      `json:"unlimited"`
-	Tags          []string  `json:"tags"`
-	Token         string    `json:"token,omitempty"` // B64 value of EnrollmentToken
-	Type          KeyType   `json:"type"`
-	Relay         uuid.UUID `json:"relay"`
-	Groups        []TagID   `json:"groups"`
-	Default       bool      `json:"default"`
-	AutoEgress    bool      `json:"auto_egress"`
+	Expiration        time.Time `json:"expiration"`
+	UsesRemaining     int       `json:"uses_remaining"`
+	Value             string    `json:"value"`
+	Networks          []string  `json:"networks"`
+	Unlimited         bool      `json:"unlimited"`
+	Tags              []string  `json:"tags"`
+	Token             string    `json:"token,omitempty"` // B64 value of EnrollmentToken
+	Type              KeyType   `json:"type"`
+	Relay             uuid.UUID `json:"relay"`
+	Groups            []TagID   `json:"groups"`
+	Default           bool      `json:"default"`
+	AutoEgress        bool      `json:"auto_egress"`
+	AutoAssignGateway bool      `json:"auto_assign_gw"`
 }
 
 // APIEnrollmentKey - used to create enrollment keys via API
 type APIEnrollmentKey struct {
-	Expiration    int64    `json:"expiration" swaggertype:"primitive,integer" format:"int64"`
-	UsesRemaining int      `json:"uses_remaining"`
-	Networks      []string `json:"networks"`
-	Unlimited     bool     `json:"unlimited"`
-	Tags          []string `json:"tags" validate:"required,dive,min=3,max=32"`
-	Type          KeyType  `json:"type"`
-	Relay         string   `json:"relay"`
-	Groups        []TagID  `json:"groups"`
-	AutoEgress    bool     `json:"auto_egress"`
+	Expiration        int64    `json:"expiration" swaggertype:"primitive,integer" format:"int64"`
+	UsesRemaining     int      `json:"uses_remaining"`
+	Networks          []string `json:"networks"`
+	Unlimited         bool     `json:"unlimited"`
+	Tags              []string `json:"tags" validate:"required,dive,min=3,max=32"`
+	Type              KeyType  `json:"type"`
+	Relay             string   `json:"relay"`
+	Groups            []TagID  `json:"groups"`
+	AutoEgress        bool     `json:"auto_egress"`
+	AutoAssignGateway bool     `json:"auto_assign_gw"`
 }
 
 // RegisterResponse - the response to a successful enrollment register

+ 3 - 0
models/host.go

@@ -128,6 +128,8 @@ const (
 	UpdateMetrics HostMqAction = "UPDATE_METRICS"
 	// EgressUpdate - const for egress update action
 	EgressUpdate HostMqAction = "EGRESS_UPDATE"
+	// CHECK_ASSIGN_GW - const for to auto assign gw action
+	CheckAutoAssignGw HostMqAction = "CHECK_AUTO_ASSIGN_GW"
 )
 
 // SignalAction - turn peer signal action
@@ -165,6 +167,7 @@ type Signal struct {
 	ToHostID       string       `json:"to_host_id"`
 	FromNodeID     string       `json:"from_node_id"`
 	ToNodeID       string       `json:"to_node_id"`
+	NetworkID      string       `json:"networkID"`
 	Reply          bool         `json:"reply"`
 	Action         SignalAction `json:"action"`
 	IsPro          bool         `json:"is_pro"`

+ 9 - 0
models/mqtt.go

@@ -13,6 +13,7 @@ type HostPeerInfo struct {
 // HostPeerUpdate - struct for host peer updates
 type HostPeerUpdate struct {
 	Host              Host                  `json:"host"`
+	Nodes             []Node                `json:"nodes"`
 	ChangeDefaultGw   bool                  `json:"change_default_gw"`
 	DefaultGwIp       net.IP                `json:"default_gw_ip"`
 	IsInternetGw      bool                  `json:"is_inet_gw"`
@@ -30,6 +31,8 @@ type HostPeerUpdate struct {
 	NameServers       []string              `json:"name_servers"`
 	DnsNameservers    []Nameserver          `json:"dns_nameservers"`
 	EgressWithDomains []EgressDomain        `json:"egress_with_domains"`
+	AutoRelayNodes    map[NetworkID][]Node  `json:"auto_relay_nodes"`
+	GwNodes           map[NetworkID][]Node  `json:"gw_nodes"`
 	ServerConfig
 	OldPeerUpdateFields
 }
@@ -132,3 +135,9 @@ type FwUpdate struct {
 type FailOverMeReq struct {
 	NodeID string `json:"node_id"`
 }
+
+// AutoRelayMeReq - struct for autorelay req
+type AutoRelayMeReq struct {
+	NodeID        string `json:"node_id"`
+	AutoRelayGwID string `json:"auto_relay_gw_id"`
+}

+ 25 - 21
models/node.go

@@ -87,34 +87,38 @@ type CommonNode struct {
 	IsGw                bool      `json:"is_gw"             yaml:"is_gw"`
 	RelayedNodes        []string  `json:"relaynodes"          yaml:"relayedNodes"`
 	IngressDNS          string    `json:"ingressdns"          yaml:"ingressdns"`
+	AutoAssignGateway   bool      `json:"auto_assign_gw"`
 }
 
 // Node - a model of a network node
 type Node struct {
 	CommonNode
-	PendingDelete              bool                 `json:"pendingdelete"           bson:"pendingdelete"           yaml:"pendingdelete"`
-	LastModified               time.Time            `json:"lastmodified"            bson:"lastmodified"            yaml:"lastmodified"`
-	LastCheckIn                time.Time            `json:"lastcheckin"             bson:"lastcheckin"             yaml:"lastcheckin"`
-	LastPeerUpdate             time.Time            `json:"lastpeerupdate"          bson:"lastpeerupdate"          yaml:"lastpeerupdate"`
-	ExpirationDateTime         time.Time            `json:"expdatetime"             bson:"expdatetime"             yaml:"expdatetime"`
-	EgressGatewayNatEnabled    bool                 `json:"egressgatewaynatenabled" bson:"egressgatewaynatenabled" yaml:"egressgatewaynatenabled"`
-	EgressGatewayRequest       EgressGatewayRequest `json:"egressgatewayrequest"    bson:"egressgatewayrequest"    yaml:"egressgatewayrequest"`
-	IngressGatewayRange        string               `json:"ingressgatewayrange"     bson:"ingressgatewayrange"     yaml:"ingressgatewayrange"`
-	IngressGatewayRange6       string               `json:"ingressgatewayrange6"    bson:"ingressgatewayrange6"    yaml:"ingressgatewayrange6"`
-	IngressPersistentKeepalive int32                `json:"ingresspersistentkeepalive"     bson:"ingresspersistentkeepalive"     yaml:"ingresspersistentkeepalive"`
-	IngressMTU                 int32                `json:"ingressmtu"     bson:"ingressmtu"     yaml:"ingressmtu"`
+	PendingDelete              bool                 `json:"pendingdelete"`
+	LastModified               time.Time            `json:"lastmodified"`
+	LastCheckIn                time.Time            `json:"lastcheckin"`
+	LastPeerUpdate             time.Time            `json:"lastpeerupdate"`
+	ExpirationDateTime         time.Time            `json:"expdatetime"`
+	EgressGatewayNatEnabled    bool                 `json:"egressgatewaynatenabled"`
+	EgressGatewayRequest       EgressGatewayRequest `json:"egressgatewayrequest"`
+	IngressGatewayRange        string               `json:"ingressgatewayrange"`
+	IngressGatewayRange6       string               `json:"ingressgatewayrange6"`
+	IngressPersistentKeepalive int32                `json:"ingresspersistentkeepalive"`
+	IngressMTU                 int32                `json:"ingressmtu"`
 	Metadata                   string               `json:"metadata"`
 	// == PRO ==
-	DefaultACL        string              `json:"defaultacl,omitempty"    bson:"defaultacl,omitempty"    yaml:"defaultacl,omitempty"    validate:"checkyesornoorunset"`
-	OwnerID           string              `json:"ownerid,omitempty"       bson:"ownerid,omitempty"       yaml:"ownerid,omitempty"`
-	IsFailOver        bool                `json:"is_fail_over"                                           yaml:"is_fail_over"`
-	FailOverPeers     map[string]struct{} `json:"fail_over_peers"                                       yaml:"fail_over_peers"`
-	FailedOverBy      uuid.UUID           `json:"failed_over_by"                                         yaml:"failed_over_by"`
-	IsInternetGateway bool                `json:"isinternetgateway"                                      yaml:"isinternetgateway"`
-	InetNodeReq       InetNodeReq         `json:"inet_node_req"                                          yaml:"inet_node_req"`
-	InternetGwID      string              `json:"internetgw_node_id"                                     yaml:"internetgw_node_id"`
-	AdditionalRagIps  []net.IP            `json:"additional_rag_ips"                                     yaml:"additional_rag_ips"                                     swaggertype:"array,number"`
-	Tags              map[TagID]struct{}  `json:"tags" yaml:"tags"`
+	DefaultACL        string              `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
+	OwnerID           string              `json:"ownerid,omitempty"`
+	IsFailOver        bool                `json:"is_fail_over"`
+	IsAutoRelay       bool                `json:"is_auto_relay"`
+	AutoRelayedPeers  map[string]struct{} `json:"auto_relayed_peers"`
+	AutoRelayedBy     uuid.UUID           `json:"auto_relayed_by"`
+	FailOverPeers     map[string]struct{} `json:"fail_over_peers"`
+	FailedOverBy      uuid.UUID           `json:"failed_over_by"`
+	IsInternetGateway bool                `json:"isinternetgateway"`
+	InetNodeReq       InetNodeReq         `json:"inet_node_req"`
+	InternetGwID      string              `json:"internetgw_node_id"`
+	AdditionalRagIps  []net.IP            `json:"additional_rag_ips" swaggertype:"array,number"`
+	Tags              map[TagID]struct{}  `json:"tags"`
 	IsStatic          bool                `json:"is_static"`
 	IsUserNode        bool                `json:"is_user_node"`
 	StaticNode        ExtClient           `json:"static_node"`

+ 2 - 3
models/structs.go

@@ -267,12 +267,11 @@ type HostPull struct {
 	NameServers       []string              `json:"name_servers"`
 	EgressWithDomains []EgressDomain        `json:"egress_with_domains"`
 	DnsNameservers    []Nameserver          `json:"dns_nameservers"`
+	AutoRelayNodes    map[NetworkID][]Node  `json:"auto_relay_nodes"`
+	GwNodes           map[NetworkID][]Node  `json:"gw_nodes"`
 	ReplacePeers      bool                  `json:"replace_peers"`
 }
 
-type DefaultGwInfo struct {
-}
-
 // NodeGet - struct for a single node get response
 type NodeGet struct {
 	Node         Node                 `json:"node" bson:"node" yaml:"node"`

+ 23 - 4
mq/handlers.go

@@ -108,21 +108,28 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
 	case models.CheckIn:
 		sendPeerUpdate = HandleHostCheckin(&hostUpdate.Host, currentHost)
 	case models.Acknowledgement:
+		nodes, err := logic.GetAllNodes()
+		if err != nil {
+			return
+		}
 		hu := hostactions.GetAction(currentHost.ID.String())
 		if hu != nil {
 			if err = HostUpdate(hu); err != nil {
 				slog.Error("failed to send new node to host", "name", hostUpdate.Host.Name, "id", currentHost.ID, "error", err)
 				return
 			} else {
-				nodes, err := logic.GetAllNodes()
-				if err != nil {
-					return
-				}
+
 				if err = PublishSingleHostPeerUpdate(currentHost, nodes, nil, nil, false, nil); err != nil {
 					slog.Error("failed peers publish after join acknowledged", "name", hostUpdate.Host.Name, "id", currentHost.ID, "error", err)
 					return
 				}
 			}
+		} else {
+			// send latest host update
+			HostUpdate(&models.HostUpdate{
+				Action: models.UpdateHost,
+				Host:   *currentHost})
+			PublishSingleHostPeerUpdate(currentHost, nodes, nil, nil, false, nil)
 		}
 	case models.UpdateHost:
 		if hostUpdate.Host.PublicKey != currentHost.PublicKey {
@@ -188,6 +195,18 @@ func SignalPeer(signal models.Signal) {
 		logger.Log(0, msg)
 		return
 	}
+	node, err := logic.GetNodeByID(signal.FromNodeID)
+	if err != nil {
+		return
+	}
+	peer, err := logic.GetNodeByID(signal.ToNodeID)
+	if err != nil {
+		return
+	}
+	if node.Network != peer.Network {
+		return
+	}
+	signal.NetworkID = node.Network
 	signal.IsPro = servercfg.IsPro
 	peerHost, err := logic.GetHost(signal.ToHostID)
 	if err != nil {

+ 621 - 0
pro/controllers/auto_relay.go

@@ -0,0 +1,621 @@
+package controllers
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+
+	"github.com/google/uuid"
+	"github.com/gorilla/mux"
+	controller "github.com/gravitl/netmaker/controllers"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
+	proLogic "github.com/gravitl/netmaker/pro/logic"
+	"github.com/gravitl/netmaker/schema"
+	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/exp/slog"
+)
+
+// AutoRelayHandlers - handlers for AutoRelay
+func AutoRelayHandlers(r *mux.Router) {
+	r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Authorize(true, false, "host", http.HandlerFunc(getAutoRelayGws))).
+		Methods(http.MethodGet)
+	r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", logic.SecurityCheck(true, http.HandlerFunc(setAutoRelay))).
+		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", logic.SecurityCheck(true, http.HandlerFunc(unsetAutoRelay))).
+		Methods(http.MethodDelete)
+	r.HandleFunc("/api/v1/node/{network}/auto_relay/reset", logic.SecurityCheck(true, http.HandlerFunc(resetAutoRelayGw))).
+		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.Authorize(true, false, "host", http.HandlerFunc(autoRelayME))).
+		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.Authorize(true, false, "host", http.HandlerFunc(autoRelayMEUpdate))).
+		Methods(http.MethodPut)
+	r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_check", controller.Authorize(true, false, "host", http.HandlerFunc(checkautoRelayCtx))).
+		Methods(http.MethodGet)
+}
+
+// @Summary     Get auto relay nodes
+// @Router      /api/v1/node/{nodeid}/auto_relay [get]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Success     200 {object} models.Node
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     404 {object} models.ErrorResponse
+func getAutoRelayGws(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	autoRelayNodes := proLogic.DoesAutoRelayExist(node.Network)
+	if len(autoRelayNodes) == 0 {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("autorelay node not found"), "notfound"),
+		)
+		return
+	}
+	defaultPolicy, err := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	returnautoRelayNodes := []models.Node{}
+	if !defaultPolicy.Enabled {
+		for _, autoRelayNode := range autoRelayNodes {
+			if logic.IsPeerAllowed(node, autoRelayNode, false) {
+				returnautoRelayNodes = append(returnautoRelayNodes, autoRelayNode)
+			}
+		}
+	} else {
+		returnautoRelayNodes = autoRelayNodes
+	}
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponseWithJson(w, r, returnautoRelayNodes, "get autorelay node successfully")
+}
+
+// @Summary     Create AutoRelay node
+// @Router      /api/v1/node/{nodeid}/auto_relay [post]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Success     200 {object} models.Node
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func setAutoRelay(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		slog.Error("failed to get node:", "error", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	err = proLogic.CreateAutoRelay(node)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	go mq.PublishPeerUpdate(false)
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponseWithJson(w, r, node, "created autorelay successfully")
+}
+
+// @Summary     Reset AutoRelay for a network
+// @Router      /api/v1/node/{network}/auto_relay/reset [post]
+// @Tags        PRO
+// @Param       network path string true "Network ID"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     500 {object} models.ErrorResponse
+func resetAutoRelayGw(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	net := params["network"]
+	nodes, err := logic.GetNetworkNodes(net)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	for _, node := range nodes {
+		if node.AutoRelayedBy != uuid.Nil {
+			node.AutoRelayedBy = uuid.Nil
+			if node.Mutex != nil {
+				node.Mutex.Lock()
+			}
+			node.AutoRelayedPeers = make(map[string]struct{})
+			if node.Mutex != nil {
+				node.Mutex.Unlock()
+			}
+			logic.UpsertNode(&node)
+		}
+	}
+	go mq.PublishPeerUpdate(false)
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "autorelay has been reset successfully")
+}
+
+// @Summary     Delete autorelay node
+// @Router      /api/v1/node/{nodeid}/auto_relay [delete]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Success     200 {object} models.Node
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func unsetAutoRelay(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		slog.Error("failed to get node:", "error", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	node.IsAutoRelay = false
+	// Reset AutoRelayed Peers
+	err = logic.UpsertNode(&node)
+	if err != nil {
+		slog.Error("failed to upsert node", "node", node.ID.String(), "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	if servercfg.CacheEnabled() {
+		proLogic.RemoveAutoRelayFromCache(node.Network)
+	}
+	go func() {
+		proLogic.ResetAutoRelay(&node)
+		mq.PublishPeerUpdate(false)
+	}()
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponseWithJson(w, r, node, "deleted autorelay successfully")
+}
+
+// @Summary     AutoRelay me
+// @Router      /api/v1/node/{nodeid}/auto_relay_me [post]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.AutoRelayMeReq true "AutoRelay request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func autoRelayME(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	var autoRelayReq models.AutoRelayMeReq
+	err = json.NewDecoder(r.Body).Decode(&autoRelayReq)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	autoRelayNode, err := logic.GetNodeByID(autoRelayReq.AutoRelayGwID)
+	if err != nil {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, autorelay node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+
+	var sendPeerUpdate bool
+	peerNode, err := logic.GetNodeByID(autoRelayReq.NodeID)
+	if err != nil {
+		slog.Error("peer not found: ", "nodeid", autoRelayReq.NodeID, "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer not found"), "badrequest"),
+		)
+		return
+	}
+	eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
+	acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
+	logic.GetNodeEgressInfo(&node, eli, acls)
+	logic.GetNodeEgressInfo(&peerNode, eli, acls)
+	logic.GetNodeEgressInfo(&autoRelayNode, eli, acls)
+	if peerNode.IsAutoRelay {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as autorelay"), "badrequest"),
+		)
+		return
+	}
+	if node.IsAutoRelay {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is acting as autorelay"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsAutoRelay {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as autorelay"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelayed && node.RelayedBy == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelay && peerNode.RelayedBy == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"),
+		)
+		return
+	}
+	if (node.InternetGwID != "" && autoRelayNode.IsInternetGateway && node.InternetGwID != autoRelayNode.ID.String()) ||
+		(peerNode.InternetGwID != "" && autoRelayNode.IsInternetGateway && peerNode.InternetGwID != autoRelayNode.ID.String()) {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node acting as internet gw for the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	err = proLogic.SetAutoRelayCtx(autoRelayNode, node, peerNode)
+	if err != nil {
+		slog.Debug("failed to create autorelay", "id", node.ID.String(),
+			"network", node.Network, "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(fmt.Errorf("failed to create autorelay: %v", err), "internal"),
+		)
+		return
+	}
+	slog.Info(
+		"[auto-relay] created relay on node",
+		"node",
+		node.ID.String(),
+		"network",
+		node.Network,
+	)
+	sendPeerUpdate = true
+
+	if sendPeerUpdate {
+		go mq.PublishPeerUpdate(false)
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "relayed successfully")
+}
+
+// @Summary     AutoRelay me
+// @Router      /api/v1/node/{nodeid}/auto_relay_me [put]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.AutoRelayMeReq true "AutoRelay request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func autoRelayMEUpdate(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	var autoRelayReq models.AutoRelayMeReq
+	err = json.NewDecoder(r.Body).Decode(&autoRelayReq)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	if autoRelayReq.AutoRelayGwID == "" {
+		// unset current gw
+		if node.RelayedBy != "" {
+			// unset relayed node from the curr relay
+			currRelayNode, err := logic.GetNodeByID(node.RelayedBy)
+			if err == nil {
+				newRelayedNodes := logic.RemoveAllFromSlice(currRelayNode.RelayedNodes, node.ID.String())
+				logic.UpdateRelayNodes(currRelayNode.ID.String(), currRelayNode.RelayedNodes, newRelayedNodes)
+			}
+		}
+		allNodes, err := logic.GetAllNodes()
+		if err == nil {
+			mq.PublishSingleHostPeerUpdate(host, allNodes, nil, nil, false, nil)
+		}
+		go mq.PublishPeerUpdate(false)
+		logic.ReturnSuccessResponse(w, r, "unrelayed successfully")
+		return
+	}
+	autoRelayNode, err := logic.GetNodeByID(autoRelayReq.AutoRelayGwID)
+	if err != nil {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, autorelay node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if !autoRelayNode.IsGw {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf(" autorelay node is not a gw"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.AutoAssignGateway {
+		if node.RelayedBy != autoRelayReq.AutoRelayGwID {
+			if node.RelayedBy != "" {
+				// unset relayed node from the curr relay
+				currRelayNode, err := logic.GetNodeByID(node.RelayedBy)
+				if err == nil {
+					newRelayedNodes := logic.RemoveAllFromSlice(currRelayNode.RelayedNodes, node.ID.String())
+					logic.UpdateRelayNodes(currRelayNode.ID.String(), currRelayNode.RelayedNodes, newRelayedNodes)
+				}
+			}
+			newNodes := []string{node.ID.String()}
+			newNodes = append(newNodes, autoRelayNode.RelayedNodes...)
+			logic.UpdateRelayNodes(autoRelayNode.ID.String(), autoRelayNode.RelayedNodes, newNodes)
+			go mq.PublishPeerUpdate(false)
+		}
+		w.Header().Set("Content-Type", "application/json")
+		logic.ReturnSuccessResponse(w, r, "relayed successfully")
+		return
+	}
+	if node.AutoRelayedBy == uuid.Nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node is not auto relayed"), "badrequest"))
+		return
+	}
+
+	if !autoRelayNode.IsAutoRelay {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("requested node is not a auto relay node"), "badrequest"))
+		return
+	}
+	if node.AutoRelayedBy == autoRelayNode.ID {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("already using requested relay node"), "badrequest"))
+		return
+	}
+	node.AutoRelayedBy = autoRelayNode.ID
+	logic.UpsertNode(&node)
+	slog.Info(
+		"[auto-relay] created relay on node",
+		"node",
+		node.ID.String(),
+		"network",
+		node.Network,
+	)
+	go mq.PublishPeerUpdate(false)
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "relayed successfully")
+}
+
+// @Summary     checkautoRelayCtx
+// @Router      /api/v1/node/{nodeid}/auto_relay_check [get]
+// @Tags        PRO
+// @Param       nodeid path string true "Node ID"
+// @Accept      json
+// @Param       body body models.AutoRelayMeReq true "autorelay request"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func checkautoRelayCtx(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	nodeid := params["nodeid"]
+	// confirm host exists
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	var autoRelayReq models.AutoRelayMeReq
+	err = json.NewDecoder(r.Body).Decode(&autoRelayReq)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	autoRelayNode, err := logic.GetNodeByID(autoRelayReq.AutoRelayGwID)
+	if err != nil {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				fmt.Errorf("req-from: %s, autorelay node doesn't exist in the network", host.Name),
+				"badrequest",
+			),
+		)
+		return
+	}
+
+	peerNode, err := logic.GetNodeByID(autoRelayReq.NodeID)
+	if err != nil {
+		slog.Error("peer not found: ", "nodeid", autoRelayReq.NodeID, "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer not found"), "badrequest"),
+		)
+		return
+	}
+	eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
+	acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
+	logic.GetNodeEgressInfo(&node, eli, acls)
+	logic.GetNodeEgressInfo(&peerNode, eli, acls)
+	logic.GetNodeEgressInfo(&autoRelayNode, eli, acls)
+	if peerNode.IsAutoRelay {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as autorelay"), "badrequest"),
+		)
+		return
+	}
+	if node.IsAutoRelay {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is acting as autorelay"), "badrequest"),
+		)
+		return
+	}
+	if peerNode.IsAutoRelay {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("peer is acting as autorelay"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelayed && node.RelayedBy == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"),
+		)
+		return
+	}
+	if node.IsRelay && peerNode.RelayedBy == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"),
+		)
+		return
+	}
+	if (node.InternetGwID != "" && autoRelayNode.IsInternetGateway && node.InternetGwID != autoRelayNode.ID.String()) ||
+		(peerNode.InternetGwID != "" && autoRelayNode.IsInternetGateway && peerNode.InternetGwID != autoRelayNode.ID.String()) {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node acting as internet gw for the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("node using a internet gw by the peer node"),
+				"badrequest",
+			),
+		)
+		return
+	}
+	if ok := logic.IsPeerAllowed(node, peerNode, true); !ok {
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(
+				errors.New("peers are not allowed to communicate"),
+				"badrequest",
+			),
+		)
+		return
+	}
+
+	err = proLogic.CheckAutoRelayCtx(autoRelayNode, node, peerNode)
+	if err != nil {
+		slog.Error("autorelay ctx cannot be set ", "error", err)
+		logic.ReturnErrorResponse(
+			w,
+			r,
+			logic.FormatError(fmt.Errorf("autorelay ctx cannot be set: %v", err), "internal"),
+		)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	logic.ReturnSuccessResponse(w, r, "autorelay can be set")
+}

+ 10 - 0
pro/initialize.go

@@ -36,6 +36,7 @@ func InitPro() {
 		proControllers.EventHandlers,
 		proControllers.TagHandlers,
 		proControllers.NetworkHandlers,
+		proControllers.AutoRelayHandlers,
 	)
 	controller.ListRoles = proControllers.ListRoles
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
@@ -92,6 +93,9 @@ func InitPro() {
 		}
 		proLogic.LoadNodeMetricsToCache()
 		proLogic.InitFailOverCache()
+		if servercfg.CacheEnabled() {
+			proLogic.InitAutoRelayCache()
+		}
 		auth.ResetIDPSyncHook()
 		email.Init()
 		go proLogic.EventWatcher()
@@ -101,6 +105,12 @@ func InitPro() {
 	logic.FailOverExists = proLogic.FailOverExists
 	logic.CreateFailOver = proLogic.CreateFailOver
 	logic.GetFailOverPeerIps = proLogic.GetFailOverPeerIps
+
+	logic.ResetAutoRelay = proLogic.ResetAutoRelay
+	logic.ResetAutoRelayedPeer = proLogic.ResetAutoRelayedPeer
+	logic.SetAutoRelay = proLogic.SetAutoRelay
+	logic.GetAutoRelayPeerIps = proLogic.GetAutoRelayPeerIps
+
 	logic.DenyClientNodeAccess = proLogic.DenyClientNode
 	logic.IsClientNodeAllowed = proLogic.IsClientNodeAllowed
 	logic.AllowClientNodeAccess = proLogic.RemoveDeniedNodeFromClient

+ 302 - 0
pro/logic/auto_relay.go

@@ -0,0 +1,302 @@
+package logic
+
+import (
+	"context"
+	"errors"
+	"net"
+	"sync"
+
+	"github.com/google/uuid"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/schema"
+	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/exp/slog"
+)
+
+var autoRelayCtxMutex = &sync.RWMutex{}
+var autoRelayCacheMutex = &sync.RWMutex{}
+var autoRelayCache = make(map[models.NetworkID][]string)
+
+func InitAutoRelayCache() {
+	autoRelayCacheMutex.Lock()
+	defer autoRelayCacheMutex.Unlock()
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		return
+	}
+	for _, node := range allNodes {
+		if node.IsAutoRelay {
+			autoRelayCache[models.NetworkID(node.Network)] = append(autoRelayCache[models.NetworkID(node.Network)], node.ID.String())
+		}
+	}
+
+}
+func SetAutoRelay(node *models.Node) {
+	node.IsAutoRelay = true
+}
+
+func CheckAutoRelayCtx(autoRelayNode, victimNode, peerNode models.Node) error {
+	autoRelayCtxMutex.RLock()
+	defer autoRelayCtxMutex.RUnlock()
+	if peerNode.AutoRelayedPeers == nil {
+		return nil
+	}
+	if victimNode.AutoRelayedPeers == nil {
+		return nil
+	}
+	if peerNode.Mutex != nil {
+		peerNode.Mutex.Lock()
+	}
+	_, peerHasAutoRelayed := peerNode.AutoRelayedPeers[victimNode.ID.String()]
+	if peerNode.Mutex != nil {
+		peerNode.Mutex.Unlock()
+	}
+	if victimNode.Mutex != nil {
+		victimNode.Mutex.Lock()
+	}
+	_, victimHasAutoRelayed := victimNode.AutoRelayedPeers[peerNode.ID.String()]
+	if victimNode.Mutex != nil {
+		victimNode.Mutex.Unlock()
+	}
+	if peerHasAutoRelayed && victimHasAutoRelayed &&
+		victimNode.AutoRelayedBy == autoRelayNode.ID && peerNode.AutoRelayedBy == autoRelayNode.ID {
+		return errors.New("auto relay ctx is already set")
+	}
+	return nil
+}
+func SetAutoRelayCtx(autoRelayNode, victimNode, peerNode models.Node) error {
+	autoRelayCtxMutex.Lock()
+	defer autoRelayCtxMutex.Unlock()
+	if peerNode.AutoRelayedPeers == nil {
+		peerNode.AutoRelayedPeers = make(map[string]struct{})
+	}
+	if victimNode.AutoRelayedPeers == nil {
+		victimNode.AutoRelayedPeers = make(map[string]struct{})
+	}
+	if peerNode.Mutex != nil {
+		peerNode.Mutex.Lock()
+	}
+	_, peerHasAutoRelayed := peerNode.AutoRelayedPeers[victimNode.ID.String()]
+	if peerNode.Mutex != nil {
+		peerNode.Mutex.Unlock()
+	}
+	if victimNode.Mutex != nil {
+		victimNode.Mutex.Lock()
+	}
+	_, victimHasAutoRelayed := victimNode.AutoRelayedPeers[peerNode.ID.String()]
+	if victimNode.Mutex != nil {
+		victimNode.Mutex.Unlock()
+	}
+	if peerHasAutoRelayed && victimHasAutoRelayed &&
+		victimNode.AutoRelayedBy == autoRelayNode.ID && peerNode.AutoRelayedBy == autoRelayNode.ID {
+		return errors.New("auto relay ctx is already set")
+	}
+	if peerNode.Mutex != nil {
+		peerNode.Mutex.Lock()
+	}
+	peerNode.AutoRelayedPeers[victimNode.ID.String()] = struct{}{}
+	if peerNode.Mutex != nil {
+		peerNode.Mutex.Unlock()
+	}
+	if victimNode.Mutex != nil {
+		victimNode.Mutex.Lock()
+	}
+	victimNode.AutoRelayedPeers[peerNode.ID.String()] = struct{}{}
+	if victimNode.Mutex != nil {
+		victimNode.Mutex.Unlock()
+	}
+	victimNode.AutoRelayedBy = autoRelayNode.ID
+	// peerNode.AutoRelayedBy = autoRelayNode.ID
+	if err := logic.UpsertNode(&victimNode); err != nil {
+		return err
+	}
+	if err := logic.UpsertNode(&peerNode); err != nil {
+		return err
+	}
+	return nil
+}
+
+// GetAutoRelayNode - gets the host acting as autoRelay
+func GetAutoRelayNode(network string, allNodes []models.Node) (models.Node, error) {
+	nodes := logic.GetNetworkNodesMemory(allNodes, network)
+	for _, node := range nodes {
+		if node.IsAutoRelay {
+			return node, nil
+		}
+	}
+	return models.Node{}, errors.New("auto relay not found")
+}
+
+func RemoveAutoRelayFromCache(network string) {
+	autoRelayCacheMutex.Lock()
+	defer autoRelayCacheMutex.Unlock()
+	delete(autoRelayCache, models.NetworkID(network))
+}
+
+func SetAutoRelayInCache(node models.Node) {
+	autoRelayCacheMutex.Lock()
+	defer autoRelayCacheMutex.Unlock()
+	autoRelayCache[models.NetworkID(node.Network)] = append(autoRelayCache[models.NetworkID(node.Network)], node.ID.String())
+}
+
+// DoesAutoRelayExist - checks if autorelay exists already in the network
+func DoesAutoRelayExist(network string) (autoRelayNodes []models.Node) {
+	autoRelayCacheMutex.RLock()
+	defer autoRelayCacheMutex.RUnlock()
+	if !servercfg.CacheEnabled() {
+		nodes, _ := logic.GetNetworkNodes(network)
+		for _, node := range nodes {
+			if node.IsAutoRelay {
+				autoRelayNodes = append(autoRelayNodes, node)
+			}
+		}
+	}
+	if nodeIDs, ok := autoRelayCache[models.NetworkID(network)]; ok {
+		for _, nodeID := range nodeIDs {
+			autoRelayNode, err := logic.GetNodeByID(nodeID)
+			if err == nil {
+				autoRelayNodes = append(autoRelayNodes, autoRelayNode)
+			}
+		}
+
+	}
+	return
+}
+
+// ResetAutoRelayedPeer - removes auto relayed over node from network peers
+func ResetAutoRelayedPeer(autoRelayedNode *models.Node) error {
+	nodes, err := logic.GetNetworkNodes(autoRelayedNode.Network)
+	if err != nil {
+		return err
+	}
+	autoRelayedNode.AutoRelayedBy = uuid.Nil
+	autoRelayedNode.AutoRelayedPeers = make(map[string]struct{})
+	err = logic.UpsertNode(autoRelayedNode)
+	if err != nil {
+		return err
+	}
+	for _, node := range nodes {
+		if node.AutoRelayedPeers == nil || node.ID == autoRelayedNode.ID {
+			continue
+		}
+		delete(node.AutoRelayedPeers, autoRelayedNode.ID.String())
+		logic.UpsertNode(&node)
+	}
+	return nil
+}
+
+// ResetAutoRelay - reset autorelayed peers
+func ResetAutoRelay(autoRelayNode *models.Node) error {
+	// Unset autorelayed peers
+	nodes, err := logic.GetNetworkNodes(autoRelayNode.Network)
+	if err != nil {
+		return err
+	}
+	for _, node := range nodes {
+		if node.AutoRelayedBy == autoRelayNode.ID {
+			node.AutoRelayedBy = uuid.Nil
+			node.AutoRelayedPeers = make(map[string]struct{})
+			logic.UpsertNode(&node)
+			for _, peer := range nodes {
+				if peer.ID == node.ID {
+					continue
+				}
+				if _, ok := peer.AutoRelayedPeers[node.ID.String()]; ok {
+					delete(peer.AutoRelayedPeers, node.ID.String())
+					logic.UpsertNode(&peer)
+				}
+			}
+		}
+	}
+	return nil
+}
+
+// GetAutoRelayPeerIps - adds the autorelayed peerIps by the peer
+func GetAutoRelayPeerIps(peer, node *models.Node) []net.IPNet {
+	allowedips := []net.IPNet{}
+	eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
+	acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
+	for autoRelayedpeerID := range node.AutoRelayedPeers {
+		autoRelayedpeer, err := logic.GetNodeByID(autoRelayedpeerID)
+		if err == nil && (autoRelayedpeer.AutoRelayedBy == peer.ID || node.AutoRelayedBy == peer.ID) {
+			logic.GetNodeEgressInfo(&autoRelayedpeer, eli, acls)
+			if autoRelayedpeer.Address.IP != nil {
+				allowed := net.IPNet{
+					IP:   autoRelayedpeer.Address.IP,
+					Mask: net.CIDRMask(32, 32),
+				}
+				allowedips = append(allowedips, allowed)
+			}
+			if autoRelayedpeer.Address6.IP != nil {
+				allowed := net.IPNet{
+					IP:   autoRelayedpeer.Address6.IP,
+					Mask: net.CIDRMask(128, 128),
+				}
+				allowedips = append(allowedips, allowed)
+			}
+			if autoRelayedpeer.EgressDetails.IsEgressGateway {
+				allowedips = append(allowedips, logic.GetEgressIPs(&autoRelayedpeer)...)
+			}
+			if autoRelayedpeer.IsRelay {
+				for _, id := range autoRelayedpeer.RelayedNodes {
+					rNode, _ := logic.GetNodeByID(id)
+					logic.GetNodeEgressInfo(&rNode, eli, acls)
+					if rNode.Address.IP != nil {
+						allowed := net.IPNet{
+							IP:   rNode.Address.IP,
+							Mask: net.CIDRMask(32, 32),
+						}
+						allowedips = append(allowedips, allowed)
+					}
+					if rNode.Address6.IP != nil {
+						allowed := net.IPNet{
+							IP:   rNode.Address6.IP,
+							Mask: net.CIDRMask(128, 128),
+						}
+						allowedips = append(allowedips, allowed)
+					}
+					if rNode.EgressDetails.IsEgressGateway {
+						allowedips = append(allowedips, logic.GetEgressIPs(&rNode)...)
+					}
+				}
+			}
+			// handle ingress gateway peers
+			if autoRelayedpeer.IsIngressGateway {
+				extPeers, _, _, err := logic.GetExtPeers(&autoRelayedpeer, node)
+				if err != nil {
+					logger.Log(2, "could not retrieve ext peers for ", peer.ID.String(), err.Error())
+				}
+				for _, extPeer := range extPeers {
+					allowedips = append(allowedips, extPeer.AllowedIPs...)
+				}
+			}
+		}
+	}
+	return allowedips
+}
+
+func CreateAutoRelay(node models.Node) error {
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		return err
+	}
+	if host.OS != models.OS_Types.Linux {
+		return errors.New("only linux nodes are allowed to be set as autoRelay")
+	}
+	if node.IsRelayed {
+		return errors.New("relayed node cannot be set as autoRelay")
+	}
+	node.IsAutoRelay = true
+	err = logic.UpsertNode(&node)
+	if err != nil {
+		slog.Error("failed to upsert node", "node", node.ID.String(), "error", err)
+		return err
+	}
+	if servercfg.CacheEnabled() {
+		SetAutoRelayInCache(node)
+	}
+	return nil
+}

+ 2 - 1
pro/logic/migrate.go

@@ -263,10 +263,11 @@ func MigrateToGws() {
 		return
 	}
 	for _, node := range nodes {
-		if node.IsIngressGateway || node.IsRelay || node.IsInternetGateway {
+		if node.IsIngressGateway || node.IsRelay || node.IsInternetGateway || node.IsFailOver {
 			node.IsGw = true
 			node.IsIngressGateway = true
 			node.IsRelay = true
+			node.IsAutoRelay = true
 			if node.Tags == nil {
 				node.Tags = make(map[models.TagID]struct{})
 			}