Ver Fonte

Merge branch 'develop' of https://github.com/gravitl/netmaker into NM-79

abhishek9686 há 1 mês atrás
pai
commit
908e4e7a77

+ 11 - 0
controllers/ext_client.go

@@ -726,6 +726,16 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 		for _, extclient := range extclients {
+			// if device id is sent, then make sure extclient with the same device id
+			// does not exist.
+			if customExtClient.DeviceID != "" && extclient.DeviceID == customExtClient.DeviceID &&
+				extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
+				err = errors.New("remote client config already exists on the gateway")
+				slog.Error("failed to create extclient", "user", userName, "error", err)
+				logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+				return
+			}
+
 			if extclient.RemoteAccessClientID != "" &&
 				extclient.RemoteAccessClientID == customExtClient.RemoteAccessClientID && extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
 				// extclient on the gw already exists for the remote access client
@@ -774,6 +784,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		extclient.Enabled = parentNetwork.DefaultACL == "yes"
 	}
 	extclient.Os = customExtClient.Os
+	extclient.DeviceID = customExtClient.DeviceID
 	extclient.DeviceName = customExtClient.DeviceName
 	if customExtClient.IsAlreadyConnectedToInetGw {
 		slog.Warn("RAC/Client is already connected to internet gateway. this may mask their real IP address", "client IP", customExtClient.PublicEndpoint)

+ 1 - 1
controllers/node.go

@@ -417,7 +417,7 @@ func getNetworkNodeStatus(w http.ResponseWriter, r *http.Request) {
 
 	}
 	nodes = logic.AddStaticNodestoList(nodes)
-	nodes = logic.AddStatusToNodes(nodes, false)
+	nodes = logic.AddStatusToNodes(nodes, true)
 	// return all the nodes in JSON/API format
 	apiNodesStatusMap := logic.GetNodesStatusAPI(nodes[:])
 	logger.Log(3, r.Header.Get("user"), "fetched all nodes they have access to")

+ 4 - 4
controllers/server.go

@@ -1,8 +1,11 @@
 package controller
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"github.com/google/go-cmp/cmp"
 	"net/http"
 	"os"
@@ -111,10 +114,7 @@ func getUsage(w http.ResponseWriter, _ *http.Request) {
 	if err == nil {
 		serverUsage.Ingresses = len(ingresses)
 	}
-	egresses, err := logic.GetAllEgresses()
-	if err == nil {
-		serverUsage.Egresses = len(egresses)
-	}
+	serverUsage.Egresses, _ = (&schema.Egress{}).Count(db.WithContext(context.TODO()))
 	relays, err := logic.GetRelays()
 	if err == nil {
 		serverUsage.Relays = len(relays)

+ 3 - 0
logic/extpeers.go

@@ -433,6 +433,9 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode
 	if update.Country != "" && update.Country != old.Country {
 		new.Country = update.Country
 	}
+	if update.DeviceID != "" && old.DeviceID == "" {
+		new.DeviceID = update.DeviceID
+	}
 	return new
 }
 

+ 1 - 1
logic/hosts.go

@@ -126,7 +126,7 @@ func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) {
 
 		nodes := GetHostNodes(&host)
 		for _, node := range nodes {
-			GetNodeCheckInStatus(&node, false)
+			getNodeCheckInStatus(&node, false)
 			if node.Status == status {
 				validHosts = append(validHosts, host)
 				break

+ 1 - 1
logic/nodes.go

@@ -471,7 +471,7 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m
 		if statusCall {
 			GetNodeStatus(&node, aclDefaultPolicyStatusMap[node.Network])
 		} else {
-			GetNodeCheckInStatus(&node, true)
+			getNodeCheckInStatus(&node, true)
 		}
 
 		nodesWithStatus = append(nodesWithStatus, node)

+ 2 - 2
logic/status.go

@@ -6,9 +6,9 @@ import (
 	"github.com/gravitl/netmaker/models"
 )
 
-var GetNodeStatus = GetNodeCheckInStatus
+var GetNodeStatus = getNodeCheckInStatus
 
-func GetNodeCheckInStatus(node *models.Node, t bool) {
+func getNodeCheckInStatus(node *models.Node, t bool) {
 	// On CE check only last check-in time
 	if node.IsStatic {
 		if !node.StaticNode.Enabled {

+ 2 - 0
models/extclient.go

@@ -24,6 +24,7 @@ type ExtClient struct {
 	PostDown               string              `json:"postdown" bson:"postdown"`
 	Tags                   map[TagID]struct{}  `json:"tags"`
 	Os                     string              `json:"os"`
+	DeviceID               string              `json:"device_id"`
 	DeviceName             string              `json:"device_name"`
 	PublicEndpoint         string              `json:"public_endpoint"`
 	Country                string              `json:"country"`
@@ -44,6 +45,7 @@ type CustomExtClient struct {
 	PostDown                   string              `json:"postdown" bson:"postdown" validate:"max=1024"`
 	Tags                       map[TagID]struct{}  `json:"tags"`
 	Os                         string              `json:"os"`
+	DeviceID                   string              `json:"device_id"`
 	DeviceName                 string              `json:"device_name"`
 	IsAlreadyConnectedToInetGw bool                `json:"is_already_connected_to_inet_gw"`
 	PublicEndpoint             string              `json:"public_endpoint"`

+ 79 - 41
pro/controllers/users.go

@@ -1382,6 +1382,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to fetch user %s, error: %v", username, err), "badrequest"))
 		return
 	}
+	deviceID := r.URL.Query().Get("device_id")
 	remoteAccessClientID := r.URL.Query().Get("remote_access_clientid")
 	var req models.UserRemoteGwsReq
 	if remoteAccessClientID == "" {
@@ -1407,58 +1408,95 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	userGwNodes := proLogic.GetUserRAGNodes(*user)
+
+	userExtClients := make(map[string][]models.ExtClient)
+
+	// group all extclients of the requesting user by ingress
+	// gateway.
 	for _, extClient := range allextClients {
-		node, ok := userGwNodes[extClient.IngressGatewayID]
+		// filter our extclients that don't belong to this user.
+		if extClient.OwnerID != username {
+			continue
+		}
+
+		_, ok := userExtClients[extClient.IngressGatewayID]
+		if !ok {
+			userExtClients[extClient.IngressGatewayID] = []models.ExtClient{}
+		}
+
+		userExtClients[extClient.IngressGatewayID] = append(userExtClients[extClient.IngressGatewayID], extClient)
+	}
+
+	for ingressGatewayID, extClients := range userExtClients {
+		node, ok := userGwNodes[ingressGatewayID]
 		if !ok {
 			continue
 		}
-		if extClient.RemoteAccessClientID == req.RemoteAccessClientID && extClient.OwnerID == username {
 
-			host, err := logic.GetHost(node.HostID.String())
-			if err != nil {
-				continue
-			}
-			network, err := logic.GetNetwork(node.Network)
-			if err != nil {
-				slog.Error("failed to get node network", "error", err)
-				continue
-			}
-			nodesWithStatus := logic.AddStatusToNodes([]models.Node{node}, false)
-			if len(nodesWithStatus) > 0 {
-				node = nodesWithStatus[0]
+		var gwClient models.ExtClient
+		var found bool
+		if deviceID != "" {
+			for _, extClient := range extClients {
+				if extClient.DeviceID == deviceID {
+					gwClient = extClient
+					found = true
+					break
+				}
 			}
+		}
 
-			gws := userGws[node.Network]
-			if extClient.DNS == "" {
-				extClient.DNS = node.IngressDNS
+		if !found {
+			// TODO: prevent ip clashes.
+			if len(extClients) > 0 {
+				gwClient = extClients[0]
 			}
+		}
 
-			extClient.IngressGatewayEndpoint = utils.GetExtClientEndpoint(
-				host.EndpointIP,
-				host.EndpointIPv6,
-				logic.GetPeerListenPort(host),
-			)
-			extClient.AllowedIPs = logic.GetExtclientAllowedIPs(extClient)
-			gws = append(gws, models.UserRemoteGws{
-				GwID:              node.ID.String(),
-				GWName:            host.Name,
-				Network:           node.Network,
-				GwClient:          extClient,
-				Connected:         true,
-				IsInternetGateway: node.IsInternetGateway,
-				GwPeerPublicKey:   host.PublicKey.String(),
-				GwListenPort:      logic.GetPeerListenPort(host),
-				Metadata:          node.Metadata,
-				AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
-				NetworkAddresses:  []string{network.AddressRange, network.AddressRange6},
-				Status:            node.Status,
-				DnsAddress:        node.IngressDNS,
-				Addresses:         utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()),
-			})
-			userGws[node.Network] = gws
-			delete(userGwNodes, node.ID.String())
+		host, err := logic.GetHost(node.HostID.String())
+		if err != nil {
+			continue
+		}
+		network, err := logic.GetNetwork(node.Network)
+		if err != nil {
+			slog.Error("failed to get node network", "error", err)
+			continue
+		}
+		nodesWithStatus := logic.AddStatusToNodes([]models.Node{node}, false)
+		if len(nodesWithStatus) > 0 {
+			node = nodesWithStatus[0]
+		}
+
+		gws := userGws[node.Network]
+		if gwClient.DNS == "" {
+			gwClient.DNS = node.IngressDNS
 		}
+
+		gwClient.IngressGatewayEndpoint = utils.GetExtClientEndpoint(
+			host.EndpointIP,
+			host.EndpointIPv6,
+			logic.GetPeerListenPort(host),
+		)
+		gwClient.AllowedIPs = logic.GetExtclientAllowedIPs(gwClient)
+		gws = append(gws, models.UserRemoteGws{
+			GwID:              node.ID.String(),
+			GWName:            host.Name,
+			Network:           node.Network,
+			GwClient:          gwClient,
+			Connected:         true,
+			IsInternetGateway: node.IsInternetGateway,
+			GwPeerPublicKey:   host.PublicKey.String(),
+			GwListenPort:      logic.GetPeerListenPort(host),
+			Metadata:          node.Metadata,
+			AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
+			NetworkAddresses:  []string{network.AddressRange, network.AddressRange6},
+			Status:            node.Status,
+			DnsAddress:        node.IngressDNS,
+			Addresses:         utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()),
+		})
+		userGws[node.Network] = gws
+		delete(userGwNodes, node.ID.String())
 	}
+
 	// add remaining gw nodes to resp
 	for gwID := range userGwNodes {
 		node, err := logic.GetNodeByID(gwID)

+ 4 - 6
pro/util.go

@@ -4,10 +4,11 @@
 package pro
 
 import (
+	"context"
 	"encoding/base64"
-
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/models"
-
+	"github.com/gravitl/netmaker/schema"
 	"github.com/gravitl/netmaker/logic"
 )
 
@@ -49,10 +50,7 @@ func getCurrentServerUsage() (limits Usage) {
 	if err == nil {
 		limits.Ingresses = len(ingresses)
 	}
-	egresses, err := logic.GetAllEgresses()
-	if err == nil {
-		limits.Egresses = len(egresses)
-	}
+	limits.Egresses, _ = (&schema.Egress{}).Count(db.WithContext(context.TODO()))
 	relays, err := logic.GetRelays()
 	if err == nil {
 		limits.Relays = len(relays)

+ 6 - 0
schema/egress.go

@@ -64,6 +64,12 @@ func (e *Egress) ListByNetwork(ctx context.Context) (egs []Egress, err error) {
 	return
 }
 
+func (e *Egress) Count(ctx context.Context) (int, error) {
+	var count int64
+	err := db.FromContext(ctx).Model(&Egress{}).Count(&count).Error
+	return int(count), err
+}
+
 func (e *Egress) Delete(ctx context.Context) error {
 	return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Delete(&e).Error
 }