Browse Source

Merge pull request #2899 from gravitl/NET-1146

NET-1146: add user id check on RAG config creation, track failover usage
Abhishek K 1 year ago
parent
commit
23359ae5ad
5 changed files with 27 additions and 2 deletions
  1. 2 2
      controllers/ext_client.go
  2. 5 0
      controllers/server.go
  3. 15 0
      logic/nodes.go
  4. 1 0
      pro/types.go
  5. 4 0
      pro/util.go

+ 2 - 2
controllers/ext_client.go

@@ -394,9 +394,9 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		}
 		}
 		for _, extclient := range extclients {
 		for _, extclient := range extclients {
 			if extclient.RemoteAccessClientID != "" &&
 			if extclient.RemoteAccessClientID != "" &&
-				extclient.RemoteAccessClientID == customExtClient.RemoteAccessClientID && nodeid == extclient.IngressGatewayID {
+				extclient.RemoteAccessClientID == customExtClient.RemoteAccessClientID && extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
 				// extclient on the gw already exists for the remote access client
 				// extclient on the gw already exists for the remote access client
-				err = errors.New("remote client config already exists on the gateway. it may have been created by another user with this same remote client machine")
+				err = errors.New("remote client config already exists on the gateway")
 				slog.Error("failed to create extclient", "user", userName, "error", err)
 				slog.Error("failed to create extclient", "user", userName, "error", err)
 				logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 				logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 				return
 				return

+ 5 - 0
controllers/server.go

@@ -55,6 +55,7 @@ func getUsage(w http.ResponseWriter, _ *http.Request) {
 		Egresses         int `json:"egresses"`
 		Egresses         int `json:"egresses"`
 		Relays           int `json:"relays"`
 		Relays           int `json:"relays"`
 		InternetGateways int `json:"internet_gateways"`
 		InternetGateways int `json:"internet_gateways"`
+		FailOvers        int `json:"fail_overs"`
 	}
 	}
 	var serverUsage usage
 	var serverUsage usage
 	hosts, err := logic.GetAllHosts()
 	hosts, err := logic.GetAllHosts()
@@ -90,6 +91,10 @@ func getUsage(w http.ResponseWriter, _ *http.Request) {
 	if err == nil {
 	if err == nil {
 		serverUsage.InternetGateways = len(gateways)
 		serverUsage.InternetGateways = len(gateways)
 	}
 	}
+	failOvers, err := logic.GetAllFailOvers()
+	if err == nil {
+		serverUsage.FailOvers = len(failOvers)
+	}
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	json.NewEncoder(w).Encode(models.SuccessResponse{
 	json.NewEncoder(w).Encode(models.SuccessResponse{
 		Code:     http.StatusOK,
 		Code:     http.StatusOK,

+ 15 - 0
logic/nodes.go

@@ -625,3 +625,18 @@ func ValidateParams(nodeid, netid string) (models.Node, error) {
 	}
 	}
 	return node, nil
 	return node, nil
 }
 }
+
+// GetAllFailOvers - gets all the nodes that are failovers
+func GetAllFailOvers() ([]models.Node, error) {
+	nodes, err := GetAllNodes()
+	if err != nil {
+		return nil, err
+	}
+	igs := make([]models.Node, 0)
+	for _, node := range nodes {
+		if node.IsFailOver {
+			igs = append(igs, node)
+		}
+	}
+	return igs, nil
+}

+ 1 - 0
pro/types.go

@@ -63,6 +63,7 @@ type Usage struct {
 	Egresses         int `json:"egresses"`
 	Egresses         int `json:"egresses"`
 	Relays           int `json:"relays"`
 	Relays           int `json:"relays"`
 	InternetGateways int `json:"internet_gateways"`
 	InternetGateways int `json:"internet_gateways"`
+	FailOvers        int `json:"fail_overs"`
 }
 }
 
 
 // Usage.SetDefaults - sets the default values for usage
 // Usage.SetDefaults - sets the default values for usage

+ 4 - 0
pro/util.go

@@ -59,5 +59,9 @@ func getCurrentServerUsage() (limits Usage) {
 	if err == nil {
 	if err == nil {
 		limits.InternetGateways = len(gateways)
 		limits.InternetGateways = len(gateways)
 	}
 	}
+	failovers, err := logic.GetAllFailOvers()
+	if err == nil {
+		limits.FailOvers = len(failovers)
+	}
 	return
 	return
 }
 }