Browse Source

NET-940: Inet Gws (#2828)

* internet gws apis

* add validate check for inet request

* add default gw changes to peer update

* update json tag

* add OS checks for inet gws

* add set defaul gw pro func

* allow disable and enable inet gw

* add inet handlers to pro

* add fields to api node

* add inet allowed ips

* add default gw to pull

* unset node inet details on deletion

* unset internet gw on network nodes

* unset inet gw fix

* unset inet gw fix

* send default gw ip

* fix inet node endpoint

* add default gw endpoint ip to pull resp

* validate after unset gws

* add inet client peer allowedips to inet node

* validate after unset gws

* fix allowed ips for inet peer and gw node

* fix allowed ips for inet peer and gw node

* fix allowed ips for inet peer and gw node

* fix allowed ips for inet peer and gw node

* fix inet gw and relayed conflict

* fix inet gw and relayed conflict

* fix update req

* fix update inet gw api

* when inet gw is peer ignore other allowedIps

* test relay

* revert test relay

* revert inet peer update changes

* channel internet traffic of relayed node to relay's inetgw

* channel internet traffic of relayed node to relay's inetgw

* channel internet traffic of relayed node to relay's inetgw

* add check for relayed node

* add inet info to peer update

* add inet info to peer update

* fix update node to persist inet info

* fix go tests

* egress ranges with inet gw fix

* egress ranges with inet gw fix

* disallow node acting using inet gw to act as inet gw

* add check to validate inet gw

* fix typos

* add firewall check

* set inetgw on ingress req on community

* set inetgw to false on community on ingress del
Abhishek K 1 year ago
parent
commit
0638dcac49

+ 1 - 1
controllers/config/dnsconfig/Corefile

@@ -1,4 +1,4 @@
-skynet  {
+. {
     reload 15s
     reload 15s
     hosts /root/dnsconfig/netmaker.hosts {
     hosts /root/dnsconfig/netmaker.hosts {
 	fallthrough	
 	fallthrough	

+ 1 - 1
controllers/ext_client.go

@@ -222,7 +222,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 		gwendpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), host.ListenPort)
 		gwendpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), host.ListenPort)
 	}
 	}
 	var newAllowedIPs string
 	var newAllowedIPs string
-	if logic.IsInternetGw(gwnode) {
+	if logic.IsInternetGw(gwnode) || gwnode.InternetGwID != "" {
 		egressrange := "0.0.0.0/0"
 		egressrange := "0.0.0.0/0"
 		if gwnode.Address6.IP != nil && client.Address6 != "" {
 		if gwnode.Address6.IP != nil && client.Address6 != "" {
 			egressrange += "," + "::/0"
 			egressrange += "," + "::/0"

+ 3 - 0
controllers/hosts.go

@@ -141,6 +141,9 @@ func pull(w http.ResponseWriter, r *http.Request) {
 		HostNetworkInfo: hPU.HostNetworkInfo,
 		HostNetworkInfo: hPU.HostNetworkInfo,
 		EgressRoutes:    hPU.EgressRoutes,
 		EgressRoutes:    hPU.EgressRoutes,
 		FwUpdate:        hPU.FwUpdate,
 		FwUpdate:        hPU.FwUpdate,
+		ChangeDefaultGw: hPU.ChangeDefaultGw,
+		DefaultGwIp:     hPU.DefaultGwIp,
+		IsInternetGw:    hPU.IsInternetGw,
 	}
 	}
 
 
 	logger.Log(1, hostID, "completed a pull")
 	logger.Log(1, hostID, "completed a pull")

+ 21 - 26
controllers/node.go

@@ -345,7 +345,7 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 
 
-	node, err := validateParams(nodeid, params["network"])
+	node, err := logic.ValidateParams(nodeid, params["network"])
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
@@ -402,9 +402,9 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 	var gateway models.EgressGatewayRequest
 	var gateway models.EgressGatewayRequest
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
-	node, err := validateParams(params["nodeid"], params["network"])
+	node, err := logic.ValidateParams(params["nodeid"], params["network"])
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
@@ -453,9 +453,9 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 	netid := params["network"]
 	netid := params["network"]
-	node, err := validateParams(nodeid, netid)
+	node, err := logic.ValidateParams(nodeid, netid)
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	node, err = logic.DeleteEgressGateway(netid, nodeid)
 	node, err = logic.DeleteEgressGateway(netid, nodeid)
@@ -497,9 +497,9 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 	netid := params["network"]
 	netid := params["network"]
-	node, err := validateParams(nodeid, netid)
+	node, err := logic.ValidateParams(nodeid, netid)
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	var request models.IngressRequest
 	var request models.IngressRequest
@@ -540,9 +540,9 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 	netid := params["network"]
 	netid := params["network"]
-	node, err := validateParams(nodeid, netid)
+	node, err := logic.ValidateParams(nodeid, netid)
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	node, removedClients, err := logic.DeleteIngressGateway(nodeid)
 	node, removedClients, err := logic.DeleteIngressGateway(nodeid)
@@ -618,9 +618,9 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 
 
 	//start here
 	//start here
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
-	currentNode, err := validateParams(nodeid, params["network"])
+	currentNode, err := logic.ValidateParams(nodeid, params["network"])
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	var newData models.ApiNode
 	var newData models.ApiNode
@@ -636,6 +636,14 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 	newNode := newData.ConvertToServerNode(&currentNode)
 	newNode := newData.ConvertToServerNode(&currentNode)
+	if newNode.IsInternetGateway != currentNode.IsInternetGateway {
+		if newNode.IsInternetGateway {
+			logic.SetInternetGw(newNode, models.InetNodeReq{})
+		} else {
+			logic.UnsetInternetGw(newNode)
+		}
+
+	}
 	relayUpdate := logic.RelayUpdates(&currentNode, newNode)
 	relayUpdate := logic.RelayUpdates(&currentNode, newNode)
 	_, err = logic.GetHost(newNode.HostID.String())
 	_, err = logic.GetHost(newNode.HostID.String())
 	if err != nil {
 	if err != nil {
@@ -695,9 +703,9 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
 	// get params
 	// get params
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	var nodeid = params["nodeid"]
 	var nodeid = params["nodeid"]
-	node, err := validateParams(nodeid, params["network"])
+	node, err := logic.ValidateParams(nodeid, params["network"])
 	if err != nil {
 	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	forceDelete := r.URL.Query().Get("force") == "true"
 	forceDelete := r.URL.Query().Get("force") == "true"
@@ -716,16 +724,3 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
 	logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"])
 	logger.Log(1, r.Header.Get("user"), "Deleted node", nodeid, "from network", params["network"])
 	go mq.PublishMqUpdatesForDeletedNode(node, !fromNode, gwClients)
 	go mq.PublishMqUpdatesForDeletedNode(node, !fromNode, gwClients)
 }
 }
-
-func validateParams(nodeid, netid string) (models.Node, error) {
-	node, err := logic.GetNodeByID(nodeid)
-	if err != nil {
-		slog.Error("error fetching node", "node", nodeid, "error", err.Error())
-		return node, fmt.Errorf("error fetching node during parameter validation: %v", err)
-	}
-	if node.Network != netid {
-		slog.Error("network url param does not match node id", "url nodeid", netid, "node", node.Network)
-		return node, fmt.Errorf("network url param does not match node network")
-	}
-	return node, nil
-}

+ 1 - 5
controllers/node_test.go

@@ -132,11 +132,7 @@ func TestGetNetworkNodes(t *testing.T) {
 
 
 func TestValidateEgressGateway(t *testing.T) {
 func TestValidateEgressGateway(t *testing.T) {
 	var gateway models.EgressGatewayRequest
 	var gateway models.EgressGatewayRequest
-	t.Run("EmptyRange", func(t *testing.T) {
-		gateway.Ranges = []string{}
-		err := logic.ValidateEgressGateway(gateway)
-		assert.EqualError(t, err, "IP Ranges Cannot Be Empty")
-	})
+
 	t.Run("Success", func(t *testing.T) {
 	t.Run("Success", func(t *testing.T) {
 		gateway.Ranges = []string{"10.100.100.0/24"}
 		gateway.Ranges = []string{"10.100.100.0/24"}
 		err := logic.ValidateEgressGateway(gateway)
 		err := logic.ValidateEgressGateway(gateway)

+ 13 - 18
logic/gateway.go

@@ -7,17 +7,13 @@ import (
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
 )
 )
 
 
-var (
-	// SetInternetGw - sets the node as internet gw based on flag bool
-	SetInternetGw = func(node *models.Node, flag bool) {
-	}
-	// IsInternetGw - checks if node is acting as internet gw
-	IsInternetGw = func(node models.Node) bool {
-		return false
-	}
-)
+// IsInternetGw - checks if node is acting as internet gw
+func IsInternetGw(node models.Node) bool {
+	return node.IsInternetGateway
+}
 
 
 // GetInternetGateways - gets all the nodes that are internet gateways
 // GetInternetGateways - gets all the nodes that are internet gateways
 func GetInternetGateways() ([]models.Node, error) {
 func GetInternetGateways() ([]models.Node, error) {
@@ -27,13 +23,8 @@ func GetInternetGateways() ([]models.Node, error) {
 	}
 	}
 	igs := make([]models.Node, 0)
 	igs := make([]models.Node, 0)
 	for _, node := range nodes {
 	for _, node := range nodes {
-		if !node.IsEgressGateway {
-			continue
-		}
-		for _, ran := range node.EgressGatewayRanges {
-			if ran == "0.0.0.0/0" {
-				igs = append(igs, node)
-			}
+		if node.IsInternetGateway {
+			igs = append(igs, node)
 		}
 		}
 	}
 	}
 	return igs, nil
 	return igs, nil
@@ -167,7 +158,9 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
 		return models.Node{}, err
 		return models.Node{}, err
 	}
 	}
 	node.IsIngressGateway = true
 	node.IsIngressGateway = true
-	SetInternetGw(&node, ingress.IsInternetGateway)
+	if !servercfg.IsPro {
+		node.IsInternetGateway = ingress.IsInternetGateway
+	}
 	node.IngressGatewayRange = network.AddressRange
 	node.IngressGatewayRange = network.AddressRange
 	node.IngressGatewayRange6 = network.AddressRange6
 	node.IngressGatewayRange6 = network.AddressRange6
 	node.IngressDNS = ingress.ExtclientDNS
 	node.IngressDNS = ingress.ExtclientDNS
@@ -223,7 +216,9 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error
 	logger.Log(3, "deleting ingress gateway")
 	logger.Log(3, "deleting ingress gateway")
 	node.LastModified = time.Now()
 	node.LastModified = time.Now()
 	node.IsIngressGateway = false
 	node.IsIngressGateway = false
-	node.IsInternetGateway = false
+	if !servercfg.IsPro {
+		node.IsInternetGateway = false
+	}
 	node.IngressGatewayRange = ""
 	node.IngressGatewayRange = ""
 	node.Metadata = ""
 	node.Metadata = ""
 	err = UpsertNode(&node)
 	err = UpsertNode(&node)

+ 30 - 0
logic/nodes.go

@@ -218,6 +218,23 @@ func DeleteNode(node *models.Node, purge bool) error {
 		// unset all the relayed nodes
 		// unset all the relayed nodes
 		SetRelayedNodes(false, node.ID.String(), node.RelayedNodes)
 		SetRelayedNodes(false, node.ID.String(), node.RelayedNodes)
 	}
 	}
+	if node.InternetGwID != "" {
+		inetNode, err := GetNodeByID(node.InternetGwID)
+		if err == nil {
+			clientNodeIDs := []string{}
+			for _, inetNodeClientID := range inetNode.InetNodeReq.InetNodeClientIDs {
+				if inetNodeClientID == node.ID.String() {
+					continue
+				}
+				clientNodeIDs = append(clientNodeIDs, inetNodeClientID)
+			}
+			inetNode.InetNodeReq.InetNodeClientIDs = clientNodeIDs
+			UpsertNode(&inetNode)
+		}
+	}
+	if node.IsInternetGateway {
+		UnsetInternetGw(node)
+	}
 
 
 	if !purge && !alreadyDeleted {
 	if !purge && !alreadyDeleted {
 		newnode := *node
 		newnode := *node
@@ -598,3 +615,16 @@ func SortApiNodes(unsortedNodes []models.ApiNode) {
 		return unsortedNodes[i].ID < unsortedNodes[j].ID
 		return unsortedNodes[i].ID < unsortedNodes[j].ID
 	})
 	})
 }
 }
+
+func ValidateParams(nodeid, netid string) (models.Node, error) {
+	node, err := GetNodeByID(nodeid)
+	if err != nil {
+		slog.Error("error fetching node", "node", nodeid, "error", err.Error())
+		return node, fmt.Errorf("error fetching node during parameter validation: %v", err)
+	}
+	if node.Network != netid {
+		slog.Error("network url param does not match node id", "url nodeid", netid, "node", node.Network)
+		return node, fmt.Errorf("network url param does not match node network")
+	}
+	return node, nil
+}

+ 45 - 20
logic/peers.go

@@ -2,6 +2,7 @@ package logic
 
 
 import (
 import (
 	"errors"
 	"errors"
+	"fmt"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
 
 
@@ -28,10 +29,30 @@ var (
 	GetFailOverPeerIps = func(peer, node *models.Node) []net.IPNet {
 	GetFailOverPeerIps = func(peer, node *models.Node) []net.IPNet {
 		return []net.IPNet{}
 		return []net.IPNet{}
 	}
 	}
-
+	// CreateFailOver - creates failover in a network
 	CreateFailOver = func(node models.Node) error {
 	CreateFailOver = func(node models.Node) error {
 		return nil
 		return nil
 	}
 	}
+
+	// SetDefaulGw
+	SetDefaultGw = func(node models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate {
+		return peerUpdate
+	}
+	SetDefaultGwForRelayedUpdate = func(relayed, relay models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate {
+		return peerUpdate
+	}
+	// UnsetInternetGw
+	UnsetInternetGw = func(node *models.Node) {
+		node.IsInternetGateway = false
+	}
+	// SetInternetGw
+	SetInternetGw = func(node *models.Node, req models.InetNodeReq) {
+		node.IsInternetGateway = true
+	}
+	// GetAllowedIpForInetNodeClient
+	GetAllowedIpForInetNodeClient = func(node, peer *models.Node) []net.IPNet {
+		return []net.IPNet{}
+	}
 )
 )
 
 
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
@@ -122,7 +143,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 			}
 			}
 			continue
 			continue
 		}
 		}
-
+		hostPeerUpdate = SetDefaultGw(node, hostPeerUpdate)
+		if !hostPeerUpdate.IsInternetGw {
+			hostPeerUpdate.IsInternetGw = IsInternetGw(node)
+		}
 		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
 		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
 		for _, peer := range currentPeers {
 		for _, peer := range currentPeers {
 			peer := peer
 			peer := peer
@@ -164,6 +188,9 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
 					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
 					continue
 					continue
 				}
 				}
+				if node.IsRelayed && node.RelayedBy == peer.ID.String() {
+					hostPeerUpdate = SetDefaultGwForRelayedUpdate(node, peer, hostPeerUpdate)
+				}
 			}
 			}
 
 
 			uselocal := false
 			uselocal := false
@@ -251,18 +278,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 				logger.Log(1, "error retrieving external clients:", err.Error())
 				logger.Log(1, "error retrieving external clients:", err.Error())
 			}
 			}
 		}
 		}
-		addedInetGwRanges := false
 		if node.IsEgressGateway && node.EgressGatewayRequest.NatEnabled == "yes" && len(node.EgressGatewayRequest.Ranges) > 0 {
 		if node.IsEgressGateway && node.EgressGatewayRequest.NatEnabled == "yes" && len(node.EgressGatewayRequest.Ranges) > 0 {
 			hostPeerUpdate.FwUpdate.IsEgressGw = true
 			hostPeerUpdate.FwUpdate.IsEgressGw = true
-			if IsInternetGw(node) {
-				hostPeerUpdate.FwUpdate.IsEgressGw = true
-				egressrange := []string{"0.0.0.0/0"}
-				if node.Address6.IP != nil {
-					egressrange = append(egressrange, "::/0")
-				}
-				node.EgressGatewayRequest.Ranges = append(node.EgressGatewayRequest.Ranges, egressrange...)
-				addedInetGwRanges = true
-			}
 			hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()] = models.EgressInfo{
 			hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()] = models.EgressInfo{
 				EgressID: node.ID.String(),
 				EgressID: node.ID.String(),
 				Network:  node.PrimaryNetworkRange(),
 				Network:  node.PrimaryNetworkRange(),
@@ -274,21 +291,21 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 			}
 			}
 
 
 		}
 		}
-		if IsInternetGw(node) && !addedInetGwRanges {
+		if IsInternetGw(node) {
 			hostPeerUpdate.FwUpdate.IsEgressGw = true
 			hostPeerUpdate.FwUpdate.IsEgressGw = true
 			egressrange := []string{"0.0.0.0/0"}
 			egressrange := []string{"0.0.0.0/0"}
 			if node.Address6.IP != nil {
 			if node.Address6.IP != nil {
 				egressrange = append(egressrange, "::/0")
 				egressrange = append(egressrange, "::/0")
 			}
 			}
-			hostPeerUpdate.FwUpdate.EgressInfo[node.ID.String()] = models.EgressInfo{
-				EgressID: node.ID.String(),
+			hostPeerUpdate.FwUpdate.EgressInfo[fmt.Sprintf("%s-%s", node.ID.String(), "inet")] = models.EgressInfo{
+				EgressID: fmt.Sprintf("%s-%s", node.ID.String(), "inet"),
 				Network:  node.PrimaryAddressIPNet(),
 				Network:  node.PrimaryAddressIPNet(),
 				EgressGwAddr: net.IPNet{
 				EgressGwAddr: net.IPNet{
 					IP:   net.ParseIP(node.PrimaryAddress()),
 					IP:   net.ParseIP(node.PrimaryAddress()),
 					Mask: getCIDRMaskFromAddr(node.PrimaryAddress()),
 					Mask: getCIDRMaskFromAddr(node.PrimaryAddress()),
 				},
 				},
 				EgressGWCfg: models.EgressGatewayRequest{
 				EgressGWCfg: models.EgressGatewayRequest{
-					NodeID:     node.ID.String(),
+					NodeID:     fmt.Sprintf("%s-%s", node.ID.String(), "inet"),
 					NetID:      node.Network,
 					NetID:      node.Network,
 					NatEnabled: "yes",
 					NatEnabled: "yes",
 					Ranges:     egressrange,
 					Ranges:     egressrange,
@@ -354,7 +371,17 @@ func GetPeerListenPort(host *models.Host) int {
 // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
 // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings
 func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet {
 func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet {
 	var allowedips []net.IPNet
 	var allowedips []net.IPNet
-	allowedips = getNodeAllowedIPs(peer, node)
+	if peer.IsInternetGateway && node.InternetGwID == peer.ID.String() {
+		allowedips = append(allowedips, GetAllowedIpForInetNodeClient(node, peer)...)
+		return allowedips
+	}
+	if node.IsRelayed && node.RelayedBy == peer.ID.String() {
+		allowedips = append(allowedips, GetAllowedIpsForRelayed(node, peer)...)
+		if peer.InternetGwID != "" {
+			return allowedips
+		}
+	}
+	allowedips = append(allowedips, getNodeAllowedIPs(peer, node)...)
 
 
 	// handle ingress gateway peers
 	// handle ingress gateway peers
 	if peer.IsIngressGateway {
 	if peer.IsIngressGateway {
@@ -366,9 +393,7 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
 			allowedips = append(allowedips, extPeer.AllowedIPs...)
 			allowedips = append(allowedips, extPeer.AllowedIPs...)
 		}
 		}
 	}
 	}
-	if node.IsRelayed && node.RelayedBy == peer.ID.String() {
-		allowedips = append(allowedips, GetAllowedIpsForRelayed(node, peer)...)
-	}
+
 	return allowedips
 	return allowedips
 }
 }
 
 

+ 2 - 1
logic/relay.go

@@ -1,8 +1,9 @@
 package logic
 package logic
 
 
 import (
 import (
-	"github.com/gravitl/netmaker/models"
 	"net"
 	"net"
+
+	"github.com/gravitl/netmaker/models"
 )
 )
 
 
 var GetRelays = func() ([]models.Node, error) {
 var GetRelays = func() ([]models.Node, error) {

+ 11 - 9
models/api_node.go

@@ -28,21 +28,22 @@ type ApiNode struct {
 	RelayedNodes            []string `json:"relaynodes" yaml:"relayedNodes"`
 	RelayedNodes            []string `json:"relaynodes" yaml:"relayedNodes"`
 	IsEgressGateway         bool     `json:"isegressgateway"`
 	IsEgressGateway         bool     `json:"isegressgateway"`
 	IsIngressGateway        bool     `json:"isingressgateway"`
 	IsIngressGateway        bool     `json:"isingressgateway"`
-	IsInternetGateway       bool     `json:"isinternetgateway" yaml:"isinternetgateway"`
 	EgressGatewayRanges     []string `json:"egressgatewayranges"`
 	EgressGatewayRanges     []string `json:"egressgatewayranges"`
 	EgressGatewayNatEnabled bool     `json:"egressgatewaynatenabled"`
 	EgressGatewayNatEnabled bool     `json:"egressgatewaynatenabled"`
 	DNSOn                   bool     `json:"dnson"`
 	DNSOn                   bool     `json:"dnson"`
 	IngressDns              string   `json:"ingressdns"`
 	IngressDns              string   `json:"ingressdns"`
 	Server                  string   `json:"server"`
 	Server                  string   `json:"server"`
-	InternetGateway         string   `json:"internetgateway"`
 	Connected               bool     `json:"connected"`
 	Connected               bool     `json:"connected"`
 	PendingDelete           bool     `json:"pendingdelete"`
 	PendingDelete           bool     `json:"pendingdelete"`
 	Metadata                string   `json:"metadata" validate:"max=256"`
 	Metadata                string   `json:"metadata" validate:"max=256"`
 	// == PRO ==
 	// == PRO ==
-	DefaultACL    string              `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
-	IsFailOver    bool                `json:"is_fail_over"`
-	FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
-	FailedOverBy  uuid.UUID           `json:"failed_over_by" yaml:"failed_over_by"`
+	DefaultACL        string              `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
+	IsFailOver        bool                `json:"is_fail_over"`
+	FailOverPeers     map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
+	FailedOverBy      uuid.UUID           `json:"failed_over_by" yaml:"failed_over_by"`
+	IsInternetGateway bool                `json:"isinternetgateway" yaml:"isinternetgateway"`
+	InetNodeReq       InetNodeReq         `json:"inet_node_req" yaml:"inet_node_req"`
+	InternetGwID      string              `json:"internetgw_node_id" yaml:"internetgw_node_id"`
 }
 }
 
 
 // ApiNode.ConvertToServerNode - converts an api node to a server node
 // ApiNode.ConvertToServerNode - converts an api node to a server node
@@ -72,6 +73,8 @@ func (a *ApiNode) ConvertToServerNode(currentNode *Node) *Node {
 	convertedNode.IsInternetGateway = a.IsInternetGateway
 	convertedNode.IsInternetGateway = a.IsInternetGateway
 	convertedNode.EgressGatewayRequest = currentNode.EgressGatewayRequest
 	convertedNode.EgressGatewayRequest = currentNode.EgressGatewayRequest
 	convertedNode.EgressGatewayNatEnabled = currentNode.EgressGatewayNatEnabled
 	convertedNode.EgressGatewayNatEnabled = currentNode.EgressGatewayNatEnabled
+	convertedNode.InternetGwID = currentNode.InternetGwID
+	convertedNode.InetNodeReq = currentNode.InetNodeReq
 	convertedNode.RelayedNodes = a.RelayedNodes
 	convertedNode.RelayedNodes = a.RelayedNodes
 	convertedNode.DefaultACL = a.DefaultACL
 	convertedNode.DefaultACL = a.DefaultACL
 	convertedNode.OwnerID = currentNode.OwnerID
 	convertedNode.OwnerID = currentNode.OwnerID
@@ -150,13 +153,12 @@ func (nm *Node) ConvertToAPINode() *ApiNode {
 	apiNode.DNSOn = nm.DNSOn
 	apiNode.DNSOn = nm.DNSOn
 	apiNode.IngressDns = nm.IngressDNS
 	apiNode.IngressDns = nm.IngressDNS
 	apiNode.Server = nm.Server
 	apiNode.Server = nm.Server
-	if isEmptyAddr(apiNode.InternetGateway) {
-		apiNode.InternetGateway = ""
-	}
 	apiNode.Connected = nm.Connected
 	apiNode.Connected = nm.Connected
 	apiNode.PendingDelete = nm.PendingDelete
 	apiNode.PendingDelete = nm.PendingDelete
 	apiNode.DefaultACL = nm.DefaultACL
 	apiNode.DefaultACL = nm.DefaultACL
 	apiNode.IsInternetGateway = nm.IsInternetGateway
 	apiNode.IsInternetGateway = nm.IsInternetGateway
+	apiNode.InternetGwID = nm.InternetGwID
+	apiNode.InetNodeReq = nm.InetNodeReq
 	apiNode.IsFailOver = nm.IsFailOver
 	apiNode.IsFailOver = nm.IsFailOver
 	apiNode.FailOverPeers = nm.FailOverPeers
 	apiNode.FailOverPeers = nm.FailOverPeers
 	apiNode.FailedOverBy = nm.FailedOverBy
 	apiNode.FailedOverBy = nm.FailedOverBy

+ 3 - 0
models/mqtt.go

@@ -9,6 +9,9 @@ import (
 // HostPeerUpdate - struct for host peer updates
 // HostPeerUpdate - struct for host peer updates
 type HostPeerUpdate struct {
 type HostPeerUpdate struct {
 	Host            Host                 `json:"host" bson:"host" yaml:"host"`
 	Host            Host                 `json:"host" bson:"host" yaml:"host"`
+	ChangeDefaultGw bool                 `json:"change_default_gw"`
+	DefaultGwIp     net.IP               `json:"default_gw_ip"`
+	IsInternetGw    bool                 `json:"is_inet_gw"`
 	NodeAddrs       []net.IPNet          `json:"nodes_addrs" yaml:"nodes_addrs"`
 	NodeAddrs       []net.IPNet          `json:"nodes_addrs" yaml:"nodes_addrs"`
 	Server          string               `json:"server" bson:"server" yaml:"server"`
 	Server          string               `json:"server" bson:"server" yaml:"server"`
 	ServerVersion   string               `json:"serverversion" bson:"serverversion" yaml:"serverversion"`
 	ServerVersion   string               `json:"serverversion" bson:"serverversion" yaml:"serverversion"`

+ 8 - 6
models/node.go

@@ -66,7 +66,6 @@ type CommonNode struct {
 	IsEgressGateway     bool      `json:"isegressgateway" yaml:"isegressgateway"`
 	IsEgressGateway     bool      `json:"isegressgateway" yaml:"isegressgateway"`
 	EgressGatewayRanges []string  `json:"egressgatewayranges" bson:"egressgatewayranges" yaml:"egressgatewayranges"`
 	EgressGatewayRanges []string  `json:"egressgatewayranges" bson:"egressgatewayranges" yaml:"egressgatewayranges"`
 	IsIngressGateway    bool      `json:"isingressgateway" yaml:"isingressgateway"`
 	IsIngressGateway    bool      `json:"isingressgateway" yaml:"isingressgateway"`
-	IsInternetGateway   bool      `json:"isinternetgateway" yaml:"isinternetgateway"`
 	IsRelayed           bool      `json:"isrelayed" bson:"isrelayed" yaml:"isrelayed"`
 	IsRelayed           bool      `json:"isrelayed" bson:"isrelayed" yaml:"isrelayed"`
 	RelayedBy           string    `json:"relayedby" bson:"relayedby" yaml:"relayedby"`
 	RelayedBy           string    `json:"relayedby" bson:"relayedby" yaml:"relayedby"`
 	IsRelay             bool      `json:"isrelay" bson:"isrelay" yaml:"isrelay"`
 	IsRelay             bool      `json:"isrelay" bson:"isrelay" yaml:"isrelay"`
@@ -89,11 +88,14 @@ type Node struct {
 	IngressGatewayRange6    string               `json:"ingressgatewayrange6" bson:"ingressgatewayrange6" yaml:"ingressgatewayrange6"`
 	IngressGatewayRange6    string               `json:"ingressgatewayrange6" bson:"ingressgatewayrange6" yaml:"ingressgatewayrange6"`
 	Metadata                string               `json:"metadata"`
 	Metadata                string               `json:"metadata"`
 	// == PRO ==
 	// == PRO ==
-	DefaultACL    string              `json:"defaultacl,omitempty" bson:"defaultacl,omitempty" yaml:"defaultacl,omitempty" validate:"checkyesornoorunset"`
-	OwnerID       string              `json:"ownerid,omitempty" bson:"ownerid,omitempty" yaml:"ownerid,omitempty"`
-	IsFailOver    bool                `json:"is_fail_over" yaml:"is_fail_over"`
-	FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
-	FailedOverBy  uuid.UUID           `json:"failed_over_by" yaml:"failed_over_by"`
+	DefaultACL        string              `json:"defaultacl,omitempty" bson:"defaultacl,omitempty" yaml:"defaultacl,omitempty" validate:"checkyesornoorunset"`
+	OwnerID           string              `json:"ownerid,omitempty" bson:"ownerid,omitempty" yaml:"ownerid,omitempty"`
+	IsFailOver        bool                `json:"is_fail_over" yaml:"is_fail_over"`
+	FailOverPeers     map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
+	FailedOverBy      uuid.UUID           `json:"failed_over_by" yaml:"failed_over_by"`
+	IsInternetGateway bool                `json:"isinternetgateway" yaml:"isinternetgateway"`
+	InetNodeReq       InetNodeReq         `json:"inet_node_req" yaml:"inet_node_req"`
+	InternetGwID      string              `json:"internetgw_node_id" yaml:"internetgw_node_id"`
 }
 }
 
 
 // LegacyNode - legacy struct for node model
 // LegacyNode - legacy struct for node model

+ 12 - 0
models/structs.go

@@ -1,6 +1,7 @@
 package models
 package models
 
 
 import (
 import (
+	"net"
 	"strings"
 	"strings"
 	"time"
 	"time"
 
 
@@ -196,6 +197,11 @@ type IngressRequest struct {
 	IsInternetGateway bool   `json:"is_internet_gw"`
 	IsInternetGateway bool   `json:"is_internet_gw"`
 }
 }
 
 
+// InetNodeReq - exit node request struct
+type InetNodeReq struct {
+	InetNodeClientIDs []string `json:"inet_node_client_ids"`
+}
+
 // ServerUpdateData - contains data to configure server
 // ServerUpdateData - contains data to configure server
 // and if it should set peers
 // and if it should set peers
 type ServerUpdateData struct {
 type ServerUpdateData struct {
@@ -234,6 +240,12 @@ type HostPull struct {
 	HostNetworkInfo HostInfoMap           `json:"host_network_info,omitempty"  yaml:"host_network_info,omitempty"`
 	HostNetworkInfo HostInfoMap           `json:"host_network_info,omitempty"  yaml:"host_network_info,omitempty"`
 	EgressRoutes    []EgressNetworkRoutes `json:"egress_network_routes"`
 	EgressRoutes    []EgressNetworkRoutes `json:"egress_network_routes"`
 	FwUpdate        FwUpdate              `json:"fw_update"`
 	FwUpdate        FwUpdate              `json:"fw_update"`
+	ChangeDefaultGw bool                  `json:"change_default_gw"`
+	DefaultGwIp     net.IP                `json:"default_gw_ip"`
+	IsInternetGw    bool                  `json:"is_inet_gw"`
+}
+
+type DefaultGwInfo struct {
 }
 }
 
 
 // NodeGet - struct for a single node get response
 // NodeGet - struct for a single node get response

+ 164 - 0
pro/controllers/inet_gws.go

@@ -0,0 +1,164 @@
+package controllers
+
+import (
+	"encoding/json"
+	"errors"
+	"net/http"
+
+	"github.com/gorilla/mux"
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
+	proLogic "github.com/gravitl/netmaker/pro/logic"
+)
+
+// InetHandlers - handlers for internet gw
+func InetHandlers(r *mux.Router) {
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(createInternetGw))).Methods(http.MethodPost)
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(updateInternetGw))).Methods(http.MethodPut)
+	r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(deleteInternetGw))).Methods(http.MethodDelete)
+}
+
+// swagger:route POST /api/nodes/{network}/{nodeid}/inet_gw nodes createInternetGw
+//
+// Create an inet node.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: nodeResponse
+func createInternetGw(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	w.Header().Set("Content-Type", "application/json")
+	nodeid := params["nodeid"]
+	netid := params["network"]
+	node, err := logic.ValidateParams(nodeid, netid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	if node.IsInternetGateway {
+		logic.ReturnSuccessResponse(w, r, "node is already acting as internet gateway")
+		return
+	}
+	var request models.InetNodeReq
+	err = json.NewDecoder(r.Body).Decode(&request)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	host, err := logic.GetHost(node.HostID.String())
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	if host.OS != models.OS_Types.Linux {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only linux nodes can be made internet gws"), "badrequest"))
+		return
+	}
+	err = proLogic.ValidateInetGwReq(node, request, false)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	proLogic.SetInternetGw(&node, request)
+	err = logic.UpsertNode(&node)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	apiNode := node.ConvertToAPINode()
+	logger.Log(1, r.Header.Get("user"), "created ingress gateway on node", nodeid, "on network", netid)
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(apiNode)
+	go mq.PublishPeerUpdate(false)
+}
+
+// swagger:route PUT /api/nodes/{network}/{nodeid}/inet_gw nodes updateInternetGw
+//
+// update an inet node.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: nodeResponse
+func updateInternetGw(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	w.Header().Set("Content-Type", "application/json")
+	nodeid := params["nodeid"]
+	netid := params["network"]
+	node, err := logic.ValidateParams(nodeid, netid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	var request models.InetNodeReq
+	err = json.NewDecoder(r.Body).Decode(&request)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	if !node.IsInternetGateway {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node is not a internet gw"), "badrequest"))
+		return
+	}
+	err = proLogic.ValidateInetGwReq(node, request, true)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	proLogic.UnsetInternetGw(&node)
+	proLogic.SetInternetGw(&node, request)
+	err = logic.UpsertNode(&node)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	apiNode := node.ConvertToAPINode()
+	logger.Log(1, r.Header.Get("user"), "created ingress gateway on node", nodeid, "on network", netid)
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(apiNode)
+	go mq.PublishPeerUpdate(false)
+}
+
+// swagger:route DELETE /api/nodes/{network}/{nodeid}/inet_gw nodes deleteInternetGw
+//
+// Delete an internet gw.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: nodeResponse
+func deleteInternetGw(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	w.Header().Set("Content-Type", "application/json")
+	nodeid := params["nodeid"]
+	netid := params["network"]
+	node, err := logic.ValidateParams(nodeid, netid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	proLogic.UnsetInternetGw(&node)
+	err = logic.UpsertNode(&node)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	apiNode := node.ConvertToAPINode()
+	logger.Log(1, r.Header.Get("user"), "created ingress gateway on node", nodeid, "on network", netid)
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(apiNode)
+	go mq.PublishPeerUpdate(false)
+}

+ 6 - 2
pro/initialize.go

@@ -31,6 +31,7 @@ func InitPro() {
 		proControllers.RelayHandlers,
 		proControllers.RelayHandlers,
 		proControllers.UserHandlers,
 		proControllers.UserHandlers,
 		proControllers.FailOverHandlers,
 		proControllers.FailOverHandlers,
+		proControllers.InetHandlers,
 	)
 	)
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
 		// == License Handling ==
 		// == License Handling ==
@@ -100,9 +101,12 @@ func InitPro() {
 	logic.UpdateRelayed = proLogic.UpdateRelayed
 	logic.UpdateRelayed = proLogic.UpdateRelayed
 	logic.SetRelayedNodes = proLogic.SetRelayedNodes
 	logic.SetRelayedNodes = proLogic.SetRelayedNodes
 	logic.RelayUpdates = proLogic.RelayUpdates
 	logic.RelayUpdates = proLogic.RelayUpdates
-	logic.IsInternetGw = proLogic.IsInternetGw
-	logic.SetInternetGw = proLogic.SetInternetGw
 	logic.GetTrialEndDate = getTrialEndDate
 	logic.GetTrialEndDate = getTrialEndDate
+	logic.SetDefaultGw = proLogic.SetDefaultGw
+	logic.SetDefaultGwForRelayedUpdate = proLogic.SetDefaultGwForRelayedUpdate
+	logic.UnsetInternetGw = proLogic.UnsetInternetGw
+	logic.SetInternetGw = proLogic.SetInternetGw
+	logic.GetAllowedIpForInetNodeClient = proLogic.GetAllowedIpForInetNodeClient
 	mq.UpdateMetrics = proLogic.MQUpdateMetrics
 	mq.UpdateMetrics = proLogic.MQUpdateMetrics
 	mq.UpdateMetricsFallBack = proLogic.MQUpdateMetricsFallBack
 	mq.UpdateMetricsFallBack = proLogic.MQUpdateMetricsFallBack
 }
 }

+ 121 - 7
pro/logic/nodes.go

@@ -1,24 +1,132 @@
 package logic
 package logic
 
 
 import (
 import (
-	celogic "github.com/gravitl/netmaker/logic"
+	"errors"
+	"fmt"
+	"net"
+
+	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
+	"golang.org/x/exp/slog"
 )
 )
 
 
-// IsInternetGw - checks if node is acting as internet gw
-func IsInternetGw(node models.Node) bool {
-	return node.IsInternetGateway
+func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool) error {
+	inetHost, err := logic.GetHost(inetNode.HostID.String())
+	if err != nil {
+		return err
+	}
+	if inetHost.FirewallInUse == models.FIREWALL_NONE {
+		return errors.New("iptables or nftables needs to be installed")
+	}
+	if inetNode.InternetGwID != "" {
+		return fmt.Errorf("node %s is using a internet gateway already", inetHost.Name)
+	}
+	if inetNode.IsRelayed {
+		return fmt.Errorf("node %s is being relayed", inetHost.Name)
+	}
+	for _, clientNodeID := range req.InetNodeClientIDs {
+		clientNode, err := logic.GetNodeByID(clientNodeID)
+		if err != nil {
+			return err
+		}
+		clientHost, err := logic.GetHost(clientNode.HostID.String())
+		if err != nil {
+			return err
+		}
+		if clientHost.OS != models.OS_Types.Linux && clientHost.OS != models.OS_Types.Windows {
+			return errors.New("can only attach linux or windows machine to a internet gateway")
+		}
+		if clientNode.IsInternetGateway {
+			return fmt.Errorf("node %s acting as internet gateway cannot use another internet gateway", clientHost.Name)
+		}
+		if update {
+			if clientNode.InternetGwID != "" && clientNode.InternetGwID != inetNode.ID.String() {
+				return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name)
+			}
+		} else {
+			if clientNode.InternetGwID != "" {
+				return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name)
+			}
+		}
+
+		if clientNode.IsRelayed {
+			return fmt.Errorf("node %s is being relayed", clientHost.Name)
+		}
+
+		for _, nodeID := range clientHost.Nodes {
+			node, err := logic.GetNodeByID(nodeID)
+			if err != nil {
+				continue
+			}
+			if node.InternetGwID != "" && node.InternetGwID != inetNode.ID.String() {
+				return errors.New("nodes on same host cannot use different internet gateway")
+			}
+
+		}
+	}
+	return nil
 }
 }
 
 
 // SetInternetGw - sets the node as internet gw based on flag bool
 // SetInternetGw - sets the node as internet gw based on flag bool
-func SetInternetGw(node *models.Node, flag bool) {
-	node.IsInternetGateway = flag
+func SetInternetGw(node *models.Node, req models.InetNodeReq) {
+	node.IsInternetGateway = true
+	node.InetNodeReq = req
+	for _, clientNodeID := range req.InetNodeClientIDs {
+		clientNode, err := logic.GetNodeByID(clientNodeID)
+		if err != nil {
+			continue
+		}
+		clientNode.InternetGwID = node.ID.String()
+		logic.UpsertNode(&clientNode)
+	}
+
+}
+
+func UnsetInternetGw(node *models.Node) {
+	nodes, err := logic.GetNetworkNodes(node.Network)
+	if err != nil {
+		slog.Error("failed to get network nodes", "network", node.Network, "error", err)
+		return
+	}
+	for _, clientNode := range nodes {
+		if node.ID.String() == clientNode.InternetGwID {
+			clientNode.InternetGwID = ""
+			logic.UpsertNode(&clientNode)
+		}
+
+	}
+	node.IsInternetGateway = false
+	node.InetNodeReq = models.InetNodeReq{}
+
+}
+
+func SetDefaultGwForRelayedUpdate(relayed, relay models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate {
+	if relay.InternetGwID != "" {
+		peerUpdate.ChangeDefaultGw = true
+		peerUpdate.DefaultGwIp = relay.Address.IP
+
+	}
+	return peerUpdate
+}
+
+func SetDefaultGw(node models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate {
+	if node.InternetGwID != "" {
+
+		inetNode, err := logic.GetNodeByID(node.InternetGwID)
+		if err != nil {
+			return peerUpdate
+		}
+		peerUpdate.ChangeDefaultGw = true
+		peerUpdate.DefaultGwIp = inetNode.Address.IP
+
+	}
+	return peerUpdate
 }
 }
 
 
 // GetNetworkIngresses - gets the gateways of a network
 // GetNetworkIngresses - gets the gateways of a network
 func GetNetworkIngresses(network string) ([]models.Node, error) {
 func GetNetworkIngresses(network string) ([]models.Node, error) {
 	var ingresses []models.Node
 	var ingresses []models.Node
-	netNodes, err := celogic.GetNetworkNodes(network)
+	netNodes, err := logic.GetNetworkNodes(network)
 	if err != nil {
 	if err != nil {
 		return []models.Node{}, err
 		return []models.Node{}, err
 	}
 	}
@@ -29,3 +137,9 @@ func GetNetworkIngresses(network string) ([]models.Node, error) {
 	}
 	}
 	return ingresses, nil
 	return ingresses, nil
 }
 }
+
+// GetAllowedIpsForInet - get inet cidr for node using a inet gw
+func GetAllowedIpForInetNodeClient(node, peer *models.Node) []net.IPNet {
+	_, ipnet, _ := net.ParseCIDR("0.0.0.0/0")
+	return []net.IPNet{*ipnet}
+}

+ 3 - 0
pro/logic/relays.go

@@ -204,6 +204,9 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe
 		logger.Log(0, "RelayedByRelay called with invalid parameters")
 		logger.Log(0, "RelayedByRelay called with invalid parameters")
 		return
 		return
 	}
 	}
+	if relay.InternetGwID != "" {
+		return GetAllowedIpForInetNodeClient(relayed, relay)
+	}
 	peers, err := logic.GetNetworkNodes(relay.Network)
 	peers, err := logic.GetNetworkNodes(relay.Network)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, "error getting network clients", err.Error())
 		logger.Log(0, "error getting network clients", err.Error())